mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
Compare commits
155 Commits
chore/depl
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d137ad31f | |||
| 759b4cbad3 | |||
| 72c92fa60a | |||
| 1ae98b3ea4 | |||
| 196c040c99 | |||
| fad5656b2e | |||
| 76fb1b6ea8 | |||
| 157ba6f5a0 | |||
| 1c0080be6f | |||
| 6b12152ce8 | |||
| 1231c2f976 | |||
| 00ac937934 | |||
| 2c323104eb | |||
| edeaac5d4e | |||
| d16a012575 | |||
| 23cd129802 | |||
| e40b30d746 | |||
| a1d9340a62 | |||
| 3addc1e386 | |||
| 24876bb05d | |||
| 0cdd478f25 | |||
| 0db9714eb6 | |||
| 9da4d167fa | |||
| a1ad4be61e | |||
| 8cb2cffbf7 | |||
| a8f009a965 | |||
| 0bfbd2061e | |||
| c8abb11bf0 | |||
| f9320b2c91 | |||
| f0fd7ddb60 | |||
| b77f5f1e4a | |||
| b67c3a5f76 | |||
| 5b5a06136a | |||
| 6e3c9597ff | |||
| 3c98f96ae8 | |||
| 44725dde74 | |||
| d3058d63bd | |||
| 4fc62d3b38 | |||
| e14cb209a4 | |||
| bb3c9929f9 | |||
| 35a55813d2 | |||
| a247d625e5 | |||
| 5c7f05bd10 | |||
| 02e1a60cde | |||
| 57b573d02b | |||
| 9de40e8f21 | |||
| cad0942f4d | |||
| cb9b1b593e | |||
| 2a8bdc2373 | |||
| ee6a07d13c | |||
| 2d6c9300e3 | |||
| d6b4c800c2 | |||
| 1b37635f92 | |||
| 86af36429d | |||
| b96ea94505 | |||
| d649cccda0 | |||
| 5cbbd78f38 | |||
| 5a0ad4ecd9 | |||
| 1e76b9e1b8 | |||
| 1b972c4e09 | |||
| 7968d2c3c8 | |||
| 7507e9ba67 | |||
| ca31762e26 | |||
| f591da7865 | |||
| f19679b217 | |||
| b682591c7a | |||
| 8f6b59feff | |||
| 99833f65d8 | |||
| 696fc5c213 | |||
| eae44cfecb | |||
| dea4e66456 | |||
| 3cd0da303a | |||
| 888483a2f8 | |||
| 7056985f72 | |||
| 6ce61eae59 | |||
| 079af312c6 | |||
| 0da13dfe4d | |||
| 1ff4d75084 | |||
| e35d23c3cb | |||
| e530e84772 | |||
| 2257a4f1ef | |||
| f465dc5090 | |||
| 5c1cfe6ada | |||
| 8d401d84c7 | |||
| b74287c2ab | |||
| c64d3e98c4 | |||
| a3265f722e | |||
| 5658065b97 | |||
| 8fc2807194 | |||
| fc7716704d | |||
| 71ffaacb58 | |||
| cfc1cf2b8c | |||
| 055d9b9f0a | |||
| 21711bebeb | |||
| becccbf288 | |||
| 86497045c9 | |||
| 687a177b24 | |||
| 4a6d278354 | |||
| 7d69302e9f | |||
| bcd573e560 | |||
| 07c0c4e7b1 | |||
| a8a2ca7b98 | |||
| de47d43b65 | |||
| 240912cef5 | |||
| 72e040ead3 | |||
| c0ee821d45 | |||
| c7c3296572 | |||
| e7be04fd58 | |||
| df6b5be50a | |||
| 8e5f09091b | |||
| 0a3005701f | |||
| d8571ce965 | |||
| f241ae25be | |||
| c6474a2a8b | |||
| 480d05bc48 | |||
| f75725ccd9 | |||
| 2fe8c48255 | |||
| ec5404cc9d | |||
| 20f62b9919 | |||
| 04f5555580 | |||
| 129af96c23 | |||
| df40960f5d | |||
| 599960024d | |||
| 6805d9bfc0 | |||
| 928f888ef5 | |||
| f46c03460e | |||
| 0b60338ad5 | |||
| 91ac465982 | |||
| 9490d63c50 | |||
| ae538ced47 | |||
| 487249728b | |||
| 372a2e3e9c | |||
| 4939a9c33d | |||
| b6f92f1dc4 | |||
| ce276573a8 | |||
| 5070cc9668 | |||
| a392a72960 | |||
| 30270b5c30 | |||
| 24715a9570 | |||
| c530a5d272 | |||
| 418ee7398e | |||
| 78f40c0d25 | |||
| 2cc567c6a3 | |||
| a180ab19e4 | |||
| 13eaa436e7 | |||
| 3596d12e4c | |||
| e8de10a3b5 | |||
| f5ab5e7eb3 | |||
| 0c40e1c2a0 | |||
| c29d76757e | |||
| 91c1d3ad81 | |||
| 57b02e341c | |||
| b94ff65e9f | |||
| 678260e34e | |||
| 739e34d08a |
@ -1,73 +1,93 @@
|
||||
---
|
||||
name: frontend-code-review
|
||||
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
|
||||
description: Review Dify frontend code for correctness, accessibility, component design, dify-ui usage, data/query boundaries, performance, and tests. Trigger for `.tsx`, `.ts`, `.js`, UI, React, Next.js, pending-change, or focused frontend review requests.
|
||||
---
|
||||
|
||||
# Frontend Code Review
|
||||
|
||||
## Intent
|
||||
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
|
||||
## When To Use
|
||||
|
||||
1. **Pending-change review** – inspect staged/working-tree files slated for commit and flag checklist violations before submission.
|
||||
2. **File-targeted review** – review the specific file(s) the user names and report the relevant checklist findings.
|
||||
Use this skill when the user asks to review, audit, analyze, or sanity-check frontend code under `web/`, `packages/dify-ui/`, or frontend-adjacent TypeScript files.
|
||||
|
||||
Stick to the checklist below for every applicable file and mode.
|
||||
Supported modes:
|
||||
|
||||
## Checklist
|
||||
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
|
||||
- **Pending-change review**: inspect staged and working-tree changes.
|
||||
- **File-focused review**: inspect explicitly named files or paths.
|
||||
- **Diff/snippet review**: review pasted diffs or snippets using best-effort references.
|
||||
|
||||
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
|
||||
Do not use this skill for backend-only code under `api/`; use `backend-code-review` instead.
|
||||
|
||||
## Required Context
|
||||
|
||||
Before reviewing, read the relevant local contracts:
|
||||
|
||||
- `web/AGENTS.md` for Dify frontend workflow, overlays, design tokens, state, and tests.
|
||||
- `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md` when code uses or changes `@langgenius/dify-ui/*`.
|
||||
- `web/docs/overlay.md` when reviewing dialogs, drawers, popovers, tooltips, menus, selects, comboboxes, or other floating UI.
|
||||
- `web/docs/test.md` and the `frontend-testing` skill when reviewing tests or testability.
|
||||
- `how-to-write-component` when reviewing React component structure, ownership, effects, query/mutation contracts, or memoization.
|
||||
|
||||
For any UI, UX, or accessibility review, fetch the latest Web Interface Guidelines before finalizing findings. Treat them as a required baseline, not the complete source of accessibility truth:
|
||||
|
||||
```text
|
||||
https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
|
||||
```
|
||||
|
||||
If the review depends on a current framework, SDK, browser API, or accessibility behavior and local code does not settle it, check the current official docs first. For browser compatibility, deprecation, or behavior-sensitive frontend APIs, verify MDN or the relevant standard.
|
||||
|
||||
## Rule Packs
|
||||
|
||||
Apply every relevant rule pack:
|
||||
|
||||
- [references/accessibility-ui.md](references/accessibility-ui.md) — accessibility, semantic HTML, focus, forms, keyboard, disabled states, copy, and long-content behavior. Combines Web Interface Guidelines with Dify UI, Base UI, MDN, and local primitive contracts.
|
||||
- [references/dify-ui.md](references/dify-ui.md) — Dify UI primitive usage, Base UI semantics, overlays, forms, tokens, radius mapping, and primitive boundaries.
|
||||
- [references/component-architecture.md](references/component-architecture.md) — component ownership, props, state, effects, exports, wrappers, and feature organization.
|
||||
- [references/data-query-contracts.md](references/data-query-contracts.md) — generated contracts, TanStack Query, mutations, workspace/auth/SSR boundaries, URL/local storage state.
|
||||
- [references/performance.md](references/performance.md) — React/Next performance review rules from Vercel guidance, scoped to real risk.
|
||||
- [references/testing.md](references/testing.md) — frontend test review rules.
|
||||
- [references/dify-invariants.md](references/dify-invariants.md) — stable Dify-specific runtime invariants that generic React/a11y rules will not catch.
|
||||
- [references/code-quality.md](references/code-quality.md) — general TypeScript, styling, naming, and maintainability rules.
|
||||
|
||||
## Review Process
|
||||
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
|
||||
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
|
||||
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
|
||||
|
||||
## Required output
|
||||
When invoked, the response must exactly follow one of the two templates:
|
||||
1. Identify the review scope. For pending changes, inspect `git diff --stat`, `git diff`, and staged diff if relevant. For file-focused reviews, stay within the named files unless a referenced owner/contract must be read.
|
||||
2. Read code around the changed lines and the owning module. Do not review by isolated snippets when nearby ownership, labels, query inputs, or overlay structure decide correctness.
|
||||
3. Check user-visible regressions first: accessibility, broken interaction, auth/permission leaks, query/hydration errors, data loss, navigation mistakes, and impossible states.
|
||||
4. Then check maintainability and performance: ownership, effects, wrappers, memoization, bundle/waterfall risks, tests, and design-system drift.
|
||||
5. Report only actionable findings. Do not list speculative risks, style preferences, or broad refactors unless they are directly tied to a reproducible issue in scope.
|
||||
|
||||
### Template A (any findings)
|
||||
```
|
||||
# Code review
|
||||
Found <N> urgent issues need to be fixed:
|
||||
## Severity
|
||||
|
||||
## 1 <brief description of bug>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
- **P0**: security/privacy/auth leak, data loss, production crash, inaccessible critical flow, or broken primary workflow.
|
||||
- **P1**: user-visible regression, hydration/SSR failure, invalid API/query contract, broken keyboard/focus behavior, or serious design-system/a11y violation.
|
||||
- **P2**: maintainability or performance issue likely to cause bugs, duplicated state, incorrect ownership, missing tests for risky behavior, or non-critical a11y issue.
|
||||
- **P3**: minor cleanup with clear value. Omit unless the user asked for a thorough audit.
|
||||
|
||||
## Output Format
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
Lead with findings, ordered by severity. Use this structure:
|
||||
|
||||
---
|
||||
... (repeat for each urgent issue) ...
|
||||
```markdown
|
||||
## Findings
|
||||
|
||||
Found <M> suggestions for improvement:
|
||||
- [P1] Short issue title
|
||||
File: `path/to/file.tsx:123`
|
||||
Why it matters and how to reproduce or reason about it.
|
||||
Suggested fix: concrete fix direction.
|
||||
|
||||
## 1 <brief description of suggestion>
|
||||
FilePath: <path> line <line>
|
||||
<relevant code snippet or pointer>
|
||||
## Open Questions
|
||||
|
||||
- Question or assumption, if any.
|
||||
|
||||
### Suggested fix
|
||||
<brief description of suggested fix>
|
||||
## Summary
|
||||
|
||||
---
|
||||
|
||||
... (repeat for each suggestion) ...
|
||||
Brief secondary context. Mention tests not run or residual risk.
|
||||
```
|
||||
|
||||
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
|
||||
|
||||
If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
|
||||
|
||||
Don't compress the blank lines between sections; keep them as-is for readability.
|
||||
|
||||
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
|
||||
|
||||
### Template B (no issues)
|
||||
```
|
||||
## Code review
|
||||
No issues found.
|
||||
```
|
||||
Rules:
|
||||
|
||||
- If there are no findings, say `No issues found.` and mention any test gaps or residual risk.
|
||||
- Always include file and line when available.
|
||||
- Keep findings concrete and reproducible.
|
||||
- Do not include praise sections by default.
|
||||
- Do not ask to apply fixes unless the user explicitly wants review plus implementation.
|
||||
|
||||
@ -0,0 +1,109 @@
|
||||
# Accessibility And UI Rules
|
||||
|
||||
Accessibility findings are first-class review findings. Treat broken keyboard access, missing accessible names, focus loss, and unreachable popup content as correctness bugs, not polish.
|
||||
|
||||
Before finalizing UI or accessibility findings, fetch the latest Web Interface Guidelines as a required baseline:
|
||||
|
||||
```text
|
||||
https://raw.githubusercontent.com/vercel-labs/web-interface-guidelines/main/command.md
|
||||
```
|
||||
|
||||
Do not treat that document as the complete accessibility rule set. Combine it with:
|
||||
|
||||
- `packages/dify-ui/README.md`, `packages/dify-ui/AGENTS.md`, and the relevant primitive implementation when code uses `@langgenius/dify-ui/*`.
|
||||
- Base UI docs and local `.d.ts` contracts when primitive semantics, focus target, labels, or popup reachability are unclear.
|
||||
- MDN or relevant WAI-ARIA/browser standards when behavior, compatibility, or deprecation status matters.
|
||||
- The current feature's product semantics, because an accessible primitive can still be used in an inaccessible workflow.
|
||||
|
||||
## Semantic HTML
|
||||
|
||||
Flag:
|
||||
|
||||
- Clickable `div` or `span` used for actions.
|
||||
- Router navigation implemented with button or `onClick` when a `Link` / `<a>` is the real semantic element.
|
||||
- Icon-only buttons without `aria-label` or `aria-labelledby`.
|
||||
- Decorative icons missing `aria-hidden="true"`.
|
||||
- Images without `alt`; use `alt=""` only when truly decorative.
|
||||
- Heading levels that skip hierarchy in page-level content.
|
||||
|
||||
Prefer semantic HTML before ARIA.
|
||||
|
||||
## Keyboard And Focus
|
||||
|
||||
Flag:
|
||||
|
||||
- Interactive elements without visible `focus-visible` treatment.
|
||||
- `outline-none` / `outline-hidden` without an equivalent focus-visible ring or state.
|
||||
- Custom interactive elements missing keyboard handling.
|
||||
- Focus trapped, lost, or sent to the wrong surface after dialog/popover/menu close.
|
||||
- Focus ring applied to the wrong DOM node. Verify the actual focus target, especially with Base UI controls such as Slider.
|
||||
|
||||
Use `focus-visible` for keyboard focus. Use `focus-within` or `has-[:focus-visible]` when the visual wrapper is not the focused element.
|
||||
|
||||
## Forms
|
||||
|
||||
Flag:
|
||||
|
||||
- Inputs, selects, switches, checkboxes, radios, comboboxes, or sliders without a label relationship.
|
||||
- Missing stable `name` on form fields that submit or validate.
|
||||
- Incorrect input `type`, `inputMode`, `autoComplete`, or `spellCheck` for email, token, URL, number, search, code, or username fields.
|
||||
- Labels that are not clickable.
|
||||
- Submit buttons disabled before a request starts, preventing normal submit behavior.
|
||||
- Non-submit buttons inside forms missing `type="button"`.
|
||||
- Errors not associated with fields or not reachable by screen readers.
|
||||
- Error recovery that does not focus or expose the first invalid field.
|
||||
- `onPaste` blocking paste.
|
||||
- Placeholder text used as the only label.
|
||||
- Password managers accidentally triggered on non-auth fields because autocomplete is missing or wrong.
|
||||
|
||||
Prefer visible labels. If visible surrounding text already labels the control, use a visually hidden label or a precise `aria-label`.
|
||||
|
||||
## Disabled, Loading, And Async States
|
||||
|
||||
Flag:
|
||||
|
||||
- Loading state without `aria-busy`, `role="status"`, or another accessible update path when it changes user interaction.
|
||||
- Spinner or decorative loading icon exposed to screen readers.
|
||||
- Disabled controls that hide the reason users cannot proceed.
|
||||
- `aria-disabled` used without manually blocking click, Space, and Enter.
|
||||
- Toasts, inline validation, or async status changes that are not announced when users need the update to continue.
|
||||
- Icon-only loading/error affordances without text or accessible status where the state matters.
|
||||
|
||||
Use native `disabled` when the control must not be interactive. Use `aria-disabled` only when the element must remain focusable and the code handles all blocked interactions.
|
||||
|
||||
For repeated shared disabled reasons, prefer a visible group message or badge plus native disabled controls. Use per-control popover/info only when the reason is item-specific.
|
||||
|
||||
## Overlays And Popup Reachability
|
||||
|
||||
Flag:
|
||||
|
||||
- Tooltip used for long, structured, interactive, or unique information.
|
||||
- Tooltip content required to understand or complete a flow.
|
||||
- PreviewCard content that touch or screen-reader users cannot reach through the trigger's click destination.
|
||||
- Popover/dialog/menu triggers without accessible names.
|
||||
- Popup content without title/description where the primitive requires them.
|
||||
|
||||
Use Popover for explanatory content, rich help, and infotips. Use Tooltip only as a short visual label for a trigger that already has an accessible name.
|
||||
|
||||
## Long Content And Layout
|
||||
|
||||
Flag:
|
||||
|
||||
- Text in flex/grid children without `min-w-0` when it can overflow.
|
||||
- Names, labels, file names, model names, workspace names, or user content lacking `truncate`, `line-clamp`, or `break-words`.
|
||||
- Right-side icons, badges, checks, or actions that shrink before the text area.
|
||||
- Empty arrays or empty strings rendering broken layout instead of an empty state.
|
||||
- Button, tab, badge, chip, menu item, or card text that can overlap sibling controls at common viewport widths.
|
||||
|
||||
The usual Dify layout chain is: container has width constraints, text region uses `min-w-0 flex-1 truncate`, adornments use `shrink-0`.
|
||||
|
||||
## Motion, Images, And Copy
|
||||
|
||||
Flag:
|
||||
|
||||
- `transition-all`.
|
||||
- Animations that do not respect reduced motion.
|
||||
- Layout-affecting animation where transform/opacity would work.
|
||||
- Images without dimensions.
|
||||
- Loading copy using `...` instead of `…`.
|
||||
- Hardcoded dates, times, numbers, or currency formats instead of `Intl.*`.
|
||||
@ -1,15 +0,0 @@
|
||||
# Rule Catalog — Business Logic
|
||||
|
||||
## Can't use workflowStore in Node components
|
||||
|
||||
IsUrgent: True
|
||||
|
||||
### Description
|
||||
|
||||
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
|
||||
|
||||
Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
|
||||
|
||||
### Suggested Fix
|
||||
|
||||
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
|
||||
@ -1,44 +1,66 @@
|
||||
# Rule Catalog — Code Quality
|
||||
# Code Quality Rules
|
||||
|
||||
## Conditional class names use utility function
|
||||
## Scope Control
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
Flag changes that expand beyond the requested feature or review scope:
|
||||
|
||||
### Description
|
||||
- Repo-wide cleanup mixed into a targeted fix.
|
||||
- Compatibility exports, aliases, shims, or wrapper layers added without an explicit migration requirement.
|
||||
- Shared abstractions created before there is stable cross-feature reuse.
|
||||
- Business components moved into generic shared locations without a clear ownership boundary.
|
||||
|
||||
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
|
||||
## TypeScript
|
||||
|
||||
### Suggested Fix
|
||||
Flag:
|
||||
|
||||
```ts
|
||||
import { cn } from '@/utils/classnames'
|
||||
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
|
||||
```
|
||||
- `any` or broad `Record<string, any>` where generated/API types or local domain types exist.
|
||||
- Re-declared API shapes instead of importing generated or returned types.
|
||||
- Weak route/query param typing that leaks `string | string[] | undefined` deep into components.
|
||||
- Runtime wrappers added only to satisfy TypeScript when a narrower type boundary would preserve the existing runtime shape.
|
||||
|
||||
## Tailwind-first styling
|
||||
Prefer:
|
||||
|
||||
IsUrgent: True
|
||||
Category: Code Quality
|
||||
- Explicit domain names that match the API contract.
|
||||
- Type narrowing at route/API boundaries.
|
||||
- Small conversion helpers colocated with the component that needs them.
|
||||
|
||||
### Description
|
||||
## Styling
|
||||
|
||||
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
|
||||
Flag:
|
||||
|
||||
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
|
||||
- New CSS modules or ad hoc CSS when Tailwind utilities and Dify tokens cover the need.
|
||||
- Generic color utilities where Dify semantic tokens exist.
|
||||
- Hardcoded magic class values for colors, spacing, radius, shadow, z-index, or typography when Dify tokens, component variants, or documented radius mappings exist.
|
||||
- `!` important modifiers or important CSS overrides without a narrow, documented reason.
|
||||
- Manual string concatenation for conditional classes.
|
||||
- JS conditional class branches for primitive visual states already exposed by Dify UI/Base UI `data-*` selectors.
|
||||
- Incoming `className` placed before default classes in `cn(...)`, preventing call-site overrides.
|
||||
- Arbitrary z-index or one-off layering fixes on overlays.
|
||||
|
||||
## Classname ordering for easy overrides
|
||||
Use:
|
||||
|
||||
### Description
|
||||
- `cn(...)` from the local package or utility already used by the file.
|
||||
- Dify semantic tokens and Tailwind v4 utilities.
|
||||
- Existing component variants before one-off class forks.
|
||||
- Primitive selectors such as `data-disabled:*`, `data-checked:*`, `data-highlighted:*`, `group-data-*`, `peer-data-*`, and `has-[:focus-visible]` before adding React state or boolean props solely for styling.
|
||||
- Component-level variants, semantic tokens, and normal cascade/order before `!` overrides. Use `!` only for a contained compatibility override that cannot be expressed through the component API or local selector structure.
|
||||
|
||||
When writing components, always place the incoming `className` prop after the component’s own class values so that downstream consumers can override or extend the styling. This keeps your component’s defaults but still lets external callers change or remove specific styles.
|
||||
## Imports
|
||||
|
||||
Example:
|
||||
Flag:
|
||||
|
||||
```tsx
|
||||
import { cn } from '@/utils/classnames'
|
||||
- Barrel imports from `@langgenius/dify-ui`; consumers must use subpath exports.
|
||||
- New overlay imports from legacy `@/app/components/base/modal`, `dialog`, or `drawer`.
|
||||
- Cross-feature imports that bypass explicit top-level public files.
|
||||
- Direct imports from generated/internal implementation files when a feature contract already exposes the intended surface.
|
||||
|
||||
const Button = ({ className }) => {
|
||||
return <div className={cn('bg-primary-600', className)}></div>
|
||||
}
|
||||
```
|
||||
## Copy And i18n
|
||||
|
||||
Flag:
|
||||
|
||||
- User-facing hardcoded strings in `web/`.
|
||||
- Translation namespace drift, especially using unrelated module namespaces for local feature copy.
|
||||
- Generic button labels like `Continue` where the action is specific.
|
||||
- Error messages that state only the failure and not the next step.
|
||||
|
||||
Use feature-local translation keys by default. Alias only when crossing namespaces.
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
# Component Architecture Rules
|
||||
|
||||
Use these rules for React component structure, ownership, state, props, effects, and module organization.
|
||||
|
||||
## Ownership
|
||||
|
||||
Flag:
|
||||
|
||||
- State, query, mutation, or handlers hoisted above the lowest component that actually uses them.
|
||||
- Parent components owning row/item actions that do not coordinate a workflow.
|
||||
- Prop drilling through multiple pass-through layers.
|
||||
- A page/tab-level section component becoming the data owner without needing a shared snapshot or shared loading/error/empty UI.
|
||||
- Feature code promoted to shared only because it appears once or might be reused later.
|
||||
|
||||
Accept repeated TanStack Query calls in siblings when each component independently consumes the data. Cache deduplication is not a reason to hoist by itself.
|
||||
|
||||
## Component Boundaries
|
||||
|
||||
Flag:
|
||||
|
||||
- Shallow wrappers that only rename props or hide the real primitive.
|
||||
- Extra DOM wrappers that do not provide layout, semantics, accessibility, state ownership, or library integration.
|
||||
- Dialog/dropdown/popover hidden surfaces that obscure the parent flow when they should be extracted into a small local component.
|
||||
- Business forms, menu bodies, or one-off helpers moved away from their owner without reuse or semantic value.
|
||||
|
||||
Prefer colocated components split by actual data and state needs.
|
||||
|
||||
## Bad Component Design Patterns
|
||||
|
||||
Flag:
|
||||
|
||||
- Components that mix data fetching, mutation side effects, popup state, form validation, layout, and row rendering without a clear owner.
|
||||
- Generic components with many boolean props that encode one feature's workflow.
|
||||
- A shared component that imports feature-specific copy, routes, or API contracts.
|
||||
- A feature component that accepts pre-rendered fragments only to avoid placing ownership correctly.
|
||||
- A child component that receives both raw server data and separately derived flags for the same concept.
|
||||
- A wrapper that changes accessible semantics of the primitive it wraps.
|
||||
- A component that exposes controlled props but still keeps a competing private state for the same value.
|
||||
- A component that cannot render empty, loading, or missing optional API fields without caller-side preprocessing.
|
||||
|
||||
## Props And Types
|
||||
|
||||
Flag:
|
||||
|
||||
- `React.FC` / `FC`.
|
||||
- Default exports outside framework-required files.
|
||||
- Named `Props` types for trivial one-off props where inline typing is clearer.
|
||||
- Props named by UI implementation instead of domain/API role.
|
||||
- API data converted too early or under a generic name that breaks traceability.
|
||||
- Callers duplicating fallback checks that the lowest rendering component already handles.
|
||||
|
||||
Prefer top-level `function` declarations for components and module helpers. Use arrow functions for callbacks and local lambdas.
|
||||
|
||||
## Effects
|
||||
|
||||
Flag effects that:
|
||||
|
||||
- Transform props/state for rendering.
|
||||
- Copy one state value into another representing the same concept.
|
||||
- Handle user actions that belong in event handlers.
|
||||
- Reset state from props when a keyed reset, stable ID, or render-time derivation would work.
|
||||
- Fetch data that belongs in framework APIs or TanStack Query.
|
||||
|
||||
If an effect remains, it must synchronize with a named external system: browser API, subscription, timer, analytics-on-visibility, non-React widget, or imperative DOM integration.
|
||||
|
||||
## State Modeling
|
||||
|
||||
Flag:
|
||||
|
||||
- Storing derived booleans, disabled flags, default tabs, or loading labels that can be calculated from current query/feature state.
|
||||
- Local state used to fake server data or generated contract fields.
|
||||
- UI state persisted to localStorage when it is live app state.
|
||||
- Feature-local mock shells wired to unrelated existing APIs before the real API is confirmed.
|
||||
|
||||
Prefer render-time derivation. Keep true local state for user choices, transient input, controlled popups, and feature UI state that has no server source.
|
||||
|
||||
## Navigation
|
||||
|
||||
Flag:
|
||||
|
||||
- Imperative router navigation for ordinary links.
|
||||
- Button semantics used for navigation.
|
||||
- Navigation state hidden in component state when URL state is required for shareable filters, tabs, or pagination.
|
||||
|
||||
Use `Link` for normal navigation. Use router APIs for mutation success, guarded redirects, command flows, or form submission side effects.
|
||||
@ -0,0 +1,74 @@
|
||||
# Data, Query, And Contract Rules
|
||||
|
||||
Use these rules for generated contracts, TanStack Query, mutations, auth/SSR boundaries, URL state, and client persistence.
|
||||
|
||||
## Generated Contracts
|
||||
|
||||
Flag:
|
||||
|
||||
- New legacy service/helper wrappers around generated `queryOptions()` or `mutationOptions()`.
|
||||
- Continuing to use deprecated contract operations when a ready generated contract exists.
|
||||
- Assuming a generated file means an operation is ready without checking deprecated markers, schema shape, and the actual UI consumer.
|
||||
- Re-declaring API DTOs in components.
|
||||
- Adding compatibility layers instead of migrating the pointed line and deleting the old layer.
|
||||
|
||||
Use `web/contract/*` as the API shape source of truth. Follow existing `{ params, query?, body? }` input shape.
|
||||
|
||||
## Queries
|
||||
|
||||
Flag:
|
||||
|
||||
- `enabled` used to hide missing required input instead of `input: skipToken`.
|
||||
- Fake fallback IDs or placeholder inputs used to force a query to run.
|
||||
- Query results copied into local state for rendering.
|
||||
- Shared query behavior such as invalidation, stale defaults, or retry rules reimplemented at call sites.
|
||||
- `prefetchQuery` treated as a hard gate or as returning data/errors to the caller.
|
||||
|
||||
Use `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))` directly unless a feature hook performs real orchestration.
|
||||
|
||||
## Mutations
|
||||
|
||||
Flag:
|
||||
|
||||
- Deprecated `useInvalid` or `useReset`.
|
||||
- `mutateAsync` used without a need for Promise semantics.
|
||||
- Awaited mutations without `try/catch`.
|
||||
- Components owning shared cache invalidation that belongs in query defaults.
|
||||
- Optimistic updates that do not match current list/detail ownership.
|
||||
|
||||
Use generated `mutationOptions()` directly when possible. Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`.
|
||||
|
||||
## SSR, Auth, And Route Boundaries
|
||||
|
||||
Flag:
|
||||
|
||||
- Request-time auth, setup, workspace role, or tenant decisions moved into static `next.config redirects()`.
|
||||
- Dynamic role gates depending on `workspaces.current` implemented as static path redirects.
|
||||
- Authorization logic depending on soft `prefetchQuery`.
|
||||
- Removing a client fallback before server API unavailable behavior is defined.
|
||||
- Global placeholder query contracts introduced to solve a route-local Suspense issue.
|
||||
- Branding-sensitive UI reading placeholder defaults without checking pending/placeholder state.
|
||||
|
||||
Separate hard gates from soft prefetches. `fetchQuery` can be a server decision boundary; `prefetchQuery` is cache warmup.
|
||||
|
||||
## Workspace And Tenant
|
||||
|
||||
Flag:
|
||||
|
||||
- Treating workspace switch as ordinary CRUD invalidation when the current app flow performs server switch plus full reload.
|
||||
- Query keys that omit workspace/tenant identity when the query truly varies by workspace and no full reload boundary applies.
|
||||
- Mixing `workspace_id` and `tenant_id` without tracing the current backend/API contract.
|
||||
|
||||
Current Dify workspace switch should be reviewed as a tenant cache boundary first.
|
||||
|
||||
## URL State And Local Storage
|
||||
|
||||
Flag:
|
||||
|
||||
- Shareable filters, tabs, pagination, selected panels, or search state hidden only in component state.
|
||||
- One-shot navigation signals modeled as subscribed persistent state.
|
||||
- Live app state stored in localStorage.
|
||||
- Direct `window.localStorage`, `globalThis.localStorage`, or raw storage calls in app code.
|
||||
- High-frequency interaction state persisted on every change instead of on commit/settle.
|
||||
|
||||
Use URL state for shareable UI state, feature/Jotai/store state for live UI state, and `@/hooks/use-local-storage` only for low-frequency client-only preferences, dismissed notices, and UI defaults.
|
||||
@ -0,0 +1,22 @@
|
||||
# Dify Invariants
|
||||
|
||||
Use these stable Dify-specific runtime rules in addition to the generic review packs.
|
||||
|
||||
This file is not a place for active feature notes. Do not add rules for one branch, one PR, or a short-lived product decision such as a specific agent-v2, plugin, model-provider, or onboarding task. Keep a rule here only when all of these are true:
|
||||
|
||||
- It is a stable Dify runtime invariant.
|
||||
- Generic React, TypeScript, accessibility, dify-ui, query, or performance rules would not catch it.
|
||||
- The failure mode is concrete enough to produce a file-line review finding.
|
||||
- The rule is likely to remain valid across normal feature work.
|
||||
|
||||
## Workflow Nodes And RAG Pipe
|
||||
|
||||
Flag:
|
||||
|
||||
- Node components under `web/app/components/workflow/nodes/[nodeName]/node.tsx` importing workflow store hooks that are unavailable in RAG Pipe template rendering.
|
||||
- Node UI relying on provider context that is not mounted in every rendering surface.
|
||||
- Store reads in render where React Flow `useNodes` / `useEdges` provide the actual node/edge source.
|
||||
|
||||
Known failure mode: workflow node components can also render while creating a RAG Pipe from a template. In that context there may be no workflowStore provider, causing a blank screen.
|
||||
|
||||
Prefer React Flow hooks for node/edge UI consumption. Use store APIs only where the provider is guaranteed and the code path is workflow-only.
|
||||
123
.agents/skills/frontend-code-review/references/dify-ui.md
Normal file
123
.agents/skills/frontend-code-review/references/dify-ui.md
Normal file
@ -0,0 +1,123 @@
|
||||
# Dify UI Rules
|
||||
|
||||
Use these rules whenever a review touches `packages/dify-ui/` or code consuming `@langgenius/dify-ui/*`.
|
||||
|
||||
Before finalizing findings for those files, read the current local docs that apply:
|
||||
|
||||
- `packages/dify-ui/README.md`
|
||||
- `packages/dify-ui/AGENTS.md`
|
||||
- `web/docs/overlay.md` for floating UI
|
||||
- `packages/dify-ui/src/<primitive>/index.tsx` for the primitive being changed or consumed
|
||||
|
||||
## Package Boundary
|
||||
|
||||
Flag in `packages/dify-ui`:
|
||||
|
||||
- Imports from `web/`.
|
||||
- Dependencies on Next.js, i18n, ky, Jotai, Zustand, TanStack Query, oRPC, or business APIs.
|
||||
- Business-specific component behavior that belongs in `web/`.
|
||||
- Multiple unrelated primitives in one component folder.
|
||||
|
||||
`packages/dify-ui` is a primitive layer: Base UI headless components + `cva` + `cn` + Dify design tokens.
|
||||
|
||||
## Imports And Exports
|
||||
|
||||
Flag:
|
||||
|
||||
- Consumer imports from `@langgenius/dify-ui` without a subpath.
|
||||
- Missing `package.json#exports` entry for a new primitive.
|
||||
- Internal package imports using workspace subpaths instead of relative paths.
|
||||
- Exported props using internal-only types that consumers cannot import from the component subpath.
|
||||
|
||||
Consumers use subpath exports such as `@langgenius/dify-ui/button`.
|
||||
|
||||
## Props And State
|
||||
|
||||
Flag:
|
||||
|
||||
- Flattened props where related values need a discriminated union, such as `value` / `defaultValue`, `multiple` / `value`, or `clearable` / `onChange`.
|
||||
- React state used only to mirror Base UI state for class names.
|
||||
- JavaScript conditional class logic for visual states that the Dify UI/Base UI primitive already exposes through `data-*` attributes or CSS variables.
|
||||
- Controlled props added when uncontrolled DOM state or CSS variables would be enough.
|
||||
- Thin wrappers that rename Base UI parts without adding semantics.
|
||||
|
||||
Prefer Base UI/Dify UI data attributes and CSS variables for visual state: `data-open`, `data-checked`, `data-disabled`, `data-highlighted`, `data-popup-open`, `group-data-*`, `peer-data-*`, `has-[:focus-visible]`, and primitive CSS variables such as anchor width or transform origin. Use JS conditional classes for product/business state that the primitive does not expose.
|
||||
|
||||
## Forms
|
||||
|
||||
Flag:
|
||||
|
||||
- Form-like UI using unrelated `Input` and `Button` pieces without a submit boundary.
|
||||
- Text-like fields not composed through `FieldRoot`, `FieldLabel`, and `FieldControl` when using Dify UI form semantics.
|
||||
- Select fields using `FieldLabel` instead of `SelectLabel`.
|
||||
- Slider fields using a generic label instead of `SliderLabel`.
|
||||
- Checkbox/radio groups missing `FieldsetRoot` and `FieldsetLegend`.
|
||||
- Field errors or descriptions rendered without `FieldDescription` / `FieldError` relationships.
|
||||
|
||||
`Form` is the submit boundary. Dify UI form primitives are not a form state-management framework; business validation and schema-driven behavior belong in `web/`.
|
||||
|
||||
## Overlay Contract
|
||||
|
||||
Flag:
|
||||
|
||||
- Legacy web overlay imports in new or modified code.
|
||||
- Manual portals around Dify UI overlay primitives.
|
||||
- Call-site `z-*` overrides on overlays.
|
||||
- Missing root `isolation: isolate` assumptions when debugging overlay stacking.
|
||||
- Repeated backdrop, z-index, or portal chrome at call sites.
|
||||
- Tooltip used for infotips, long text, or interactive content.
|
||||
|
||||
All Dify UI body-portalled overlays use `z-50`. Toast uses `z-60`. DOM order handles stacking between overlays.
|
||||
|
||||
## Primitive Selection
|
||||
|
||||
Flag:
|
||||
|
||||
- `Tabs` used for simple mode/filter/view selection where `SegmentedControl` is the semantic primitive.
|
||||
- `SegmentedControl` used where `tablist` / `tabpanel` semantics are required.
|
||||
- `Select` used for searchable or free-form input.
|
||||
- `Combobox` used for unrestricted search text where no selected option is remembered.
|
||||
- `Autocomplete` used for closed-list selection.
|
||||
- Tooltip or PreviewCard used for content that must be reachable on touch or by screen readers.
|
||||
|
||||
Use:
|
||||
|
||||
- `Autocomplete` for free-form text with optional suggestions.
|
||||
- `Combobox` for searchable selected values from a collection.
|
||||
- `Select` for closed, scannable option sets.
|
||||
- `Popover` for infotips, help text, rich content, or interactions.
|
||||
|
||||
## Bad Usage Patterns To Flag
|
||||
|
||||
Flag:
|
||||
|
||||
- Styling a raw Base UI primitive directly in `web/` when a Dify UI primitive exists.
|
||||
- Wrapping a Dify UI primitive in a feature component that hides its label, error, disabled, or focus contract.
|
||||
- Replacing a semantic primitive with a generic `div` plus classes to match a screenshot.
|
||||
- Using `Tooltip` because it is visually convenient when the content is actually help text or needs touch access.
|
||||
- Adding a `z-*` override to make a child popup appear over a parent dialog.
|
||||
- Adding a new app-level wrapper around Dialog, Drawer, Popover, Select, or Combobox that repeats portal/backdrop/positioner logic.
|
||||
- Using dify-ui `Input` as a drop-in replacement for legacy inputs that include search, clear, copy, unit, localized placeholder, or number normalization behavior.
|
||||
- Building a form row from loose text and controls instead of the matching Field/Form primitives.
|
||||
- Adding component state only to style `data-open`, `data-checked`, `data-disabled`, or highlighted states that Base UI already exposes.
|
||||
- Passing booleans down only so children can toggle classes already expressible with primitive `data-*` selectors.
|
||||
|
||||
## Tokens, Radius, And Styling
|
||||
|
||||
Flag:
|
||||
|
||||
- `radius-*` class names.
|
||||
- Custom Tailwind `borderRadius` extension for Figma radius values.
|
||||
- Generic colors where semantic Dify tokens exist.
|
||||
- Hardcoded design values where Dify tokens, component variants, or documented Figma radius mappings exist.
|
||||
- `!` important modifiers used to fight primitive styles instead of fixing the variant, selector, or component composition.
|
||||
- Manual class strings that duplicate primitive variants.
|
||||
- `min-w-(--anchor-width)` on picker popups when it defeats viewport clamping.
|
||||
|
||||
Use the Figma radius mapping from `packages/dify-ui/AGENTS.md`; for example `--radius/sm` maps to `rounded-md`, and `--radius/md` maps to `rounded-lg`.
|
||||
|
||||
Use `!` only for a tightly scoped compatibility override after confirming the primitive API, data attributes, and selector structure cannot express the state.
|
||||
|
||||
## Focus Details
|
||||
|
||||
Flag focus rings attached to the wrong element. For example, Base UI `Slider.Thumb` focuses an internal `input[type=range]`, so the visible thumb wrapper needs `has-[:focus-visible]` rather than direct wrapper `focus-visible`.
|
||||
@ -1,45 +1,78 @@
|
||||
# Rule Catalog — Performance
|
||||
# Performance Rules
|
||||
|
||||
## React Flow data usage
|
||||
Review performance only where there is realistic impact. Do not request `memo`, `useMemo`, `useCallback`, virtualization, or caching as style preferences.
|
||||
|
||||
IsUrgent: True
|
||||
Category: Performance
|
||||
## Async Waterfalls
|
||||
|
||||
### Description
|
||||
Flag:
|
||||
|
||||
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
|
||||
- Awaiting remote feature flags or fetches before checking cheap synchronous conditions.
|
||||
- Sequential awaits for independent operations.
|
||||
- API routes or server components starting requests late when they could start early.
|
||||
- Nested per-item fetches running serially when each item can fetch in parallel.
|
||||
- Suspense boundaries that force the whole page to wait when a lower boundary could stream or isolate loading.
|
||||
|
||||
## Complex prop stability
|
||||
Prefer `Promise.all` for independent work and branch-local awaits for conditionally needed data.
|
||||
|
||||
IsUrgent: False
|
||||
Category: Performance
|
||||
## Bundle Size
|
||||
|
||||
### Description
|
||||
Flag:
|
||||
|
||||
Only require stable object, array, or map props when there is a clear reason: the child is memoized, the value participates in effect/query dependencies, the value is part of a stable-reference API contract, or profiling/local behavior shows avoidable re-renders. Do not request `useMemo` for every inline object by default; `how-to-write-component` treats memoization as a targeted optimization.
|
||||
- Barrel imports from heavy libraries or `@langgenius/dify-ui`.
|
||||
- Dynamic paths that prevent static trace analysis.
|
||||
- Heavy components loaded eagerly when hidden behind a dialog, tab, command, or feature activation.
|
||||
- Analytics, logging, editor, visualization, or third-party SDK code loaded before it is needed.
|
||||
- Feature-local optional modules imported at top level only for rare flows.
|
||||
|
||||
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
|
||||
Use direct imports and `next/dynamic` where the user-visible path benefits.
|
||||
|
||||
Risky:
|
||||
## Server Rendering
|
||||
|
||||
```tsx
|
||||
<HeavyComp
|
||||
config={{
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}}
|
||||
/>
|
||||
```
|
||||
Flag:
|
||||
|
||||
Better when stable identity matters:
|
||||
- Request-specific mutable state stored at module scope in SSR/RSC paths.
|
||||
- Large duplicate data serialized across RSC/client boundaries.
|
||||
- Static I/O repeated per request when it could be hoisted safely.
|
||||
- Cross-request cache without a bounded invalidation strategy.
|
||||
- Server actions lacking API-route-equivalent auth checks.
|
||||
|
||||
```tsx
|
||||
const config = useMemo(() => ({
|
||||
provider: ...,
|
||||
detail: ...
|
||||
}), [provider, detail]);
|
||||
Use request-scoped deduplication such as `React.cache()` when repeated server reads in one request are the problem.
|
||||
|
||||
<HeavyComp
|
||||
config={config}
|
||||
/>
|
||||
```
|
||||
## Re-rendering
|
||||
|
||||
Flag:
|
||||
|
||||
- Effects or subscriptions reading broad state when a derived boolean or narrower selector is enough.
|
||||
- Components defined inside components.
|
||||
- Derived rendering state stored in state/effects.
|
||||
- Non-primitive default props recreated for memoized children.
|
||||
- Expensive work recalculated on every render where it affects real interaction cost.
|
||||
- High-frequency transient values stored in state when refs or CSS variables would avoid render loops.
|
||||
|
||||
Do not flag simple primitive expressions wrapped or not wrapped in `useMemo`; prefer no memo for simple work.
|
||||
|
||||
Require stable object/array/function identity only when:
|
||||
|
||||
- The child is memoized and identity affects renders.
|
||||
- The value is an effect/query dependency.
|
||||
- A library API requires stable references.
|
||||
- Profiling or local behavior shows avoidable re-rendering.
|
||||
|
||||
## DOM, Lists, And Rendering
|
||||
|
||||
Flag:
|
||||
|
||||
- Layout reads in render (`getBoundingClientRect`, `offset*`, `scrollTop`).
|
||||
- Interleaved DOM reads/writes that can cause layout thrashing.
|
||||
- Large lists rendering without virtualization, pagination, or `content-visibility`.
|
||||
- SVG/animation code animating expensive properties when transform/opacity would work.
|
||||
- `transition-all`.
|
||||
- Long-running non-critical browser work performed immediately instead of idle/deferred scheduling.
|
||||
|
||||
## React Flow
|
||||
|
||||
For workflow React Flow components, keep this Dify-specific rule:
|
||||
|
||||
- UI consumption should use React Flow hooks such as `useNodes` / `useEdges`.
|
||||
- Callback-only reads or mutations can use `useStoreApi`.
|
||||
- Node components under `web/app/components/workflow/nodes/[nodeName]/node.tsx` must not depend on workflow stores that are absent in RAG Pipe template rendering.
|
||||
|
||||
72
.agents/skills/frontend-code-review/references/testing.md
Normal file
72
.agents/skills/frontend-code-review/references/testing.md
Normal file
@ -0,0 +1,72 @@
|
||||
# Testing Review Rules
|
||||
|
||||
Use these rules when reviewing test files, testability of changed code, or risky frontend changes that should have tests.
|
||||
|
||||
## Missing Coverage
|
||||
|
||||
Flag missing tests when the change affects:
|
||||
|
||||
- User-visible behavior, navigation, form submission, validation, permissions, or loading/error/empty states.
|
||||
- Query/mutation cache behavior.
|
||||
- Accessibility-critical behavior such as labels, keyboard flow, focus, disabled state, or popup reachability.
|
||||
- URL state parsing/serialization.
|
||||
- Storage persistence or one-shot signals.
|
||||
- Regression-prone workflow or generated contract migration paths.
|
||||
|
||||
Do not request tests for purely mechanical renames or styling-only changes unless the styling affects layout, focus, or interaction.
|
||||
|
||||
## Selectors
|
||||
|
||||
Flag:
|
||||
|
||||
- `getByTestId` used where role, label, text, placeholder, landmark, or scoped dialog/menu queries are available.
|
||||
- Production `data-testid` added only to satisfy tests.
|
||||
- Assertions against decorative icons rather than the named control.
|
||||
- Tests that cannot find controls semantically but leave broken markup unchanged.
|
||||
|
||||
Prefer `getByRole` with accessible name, then `getByLabelText`, `getByPlaceholderText`, `getByText`, and `within(...)`.
|
||||
|
||||
## Mocking
|
||||
|
||||
Flag:
|
||||
|
||||
- Mocking `@langgenius/dify-ui/*` primitives.
|
||||
- Mocking `@/app/components/base/*` components when the real component is practical.
|
||||
- Mocking sibling or child components in the same directory for integration behavior.
|
||||
- Mocks that do not match the real component's conditional rendering.
|
||||
- Module-level mock state not reset in `beforeEach`.
|
||||
- `vi.clearAllMocks()` in `afterEach` instead of `beforeEach`.
|
||||
|
||||
Use real project components for integration behavior. Mock APIs, `next/navigation`, browser shims, or complex providers only when setup would dominate the test.
|
||||
|
||||
## Behavior
|
||||
|
||||
Flag:
|
||||
|
||||
- Tests inspecting implementation details instead of user-observable behavior.
|
||||
- Assertions that hardcode brittle copy when pattern matching or semantic roles would express behavior better.
|
||||
- Fake timers used without real timing behavior.
|
||||
- Async assertions missing `await`, `findBy*`, or `waitFor`.
|
||||
- Test data missing required fields because inline partial objects bypass real types.
|
||||
|
||||
Use typed factory functions with complete defaults and partial overrides.
|
||||
|
||||
## URL State
|
||||
|
||||
For `nuqs` or query-state hooks, flag tests that:
|
||||
|
||||
- Mock URL state when URL synchronization is the behavior under review.
|
||||
- Do not test parser serialize/parse round trips for custom parsers.
|
||||
- Do not assert default-clearing behavior when defaults should be removed from the URL.
|
||||
|
||||
Prefer shared `NuqsTestingAdapter` helpers when available.
|
||||
|
||||
## Organization
|
||||
|
||||
Flag:
|
||||
|
||||
- Component/hook/util tests outside sibling `__tests__/` directories.
|
||||
- Directory-level reviews that test only `index.tsx` while other files in scope contain behavior.
|
||||
- Large test files with repeated setup that should use local builders.
|
||||
|
||||
When a component is very complex, prefer a refactor finding before asking for exhaustive tests.
|
||||
@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
1
.claude/skills/how-to-write-component
Symbolic link
1
.claude/skills/how-to-write-component
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/how-to-write-component
|
||||
13
.github/CODEOWNERS
vendored
13
.github/CODEOWNERS
vendored
@ -15,6 +15,10 @@
|
||||
# Agents
|
||||
/.agents/skills/ @hyoban
|
||||
|
||||
# Packages
|
||||
/packages/ @lyzno1
|
||||
/packages/contracts/ @crazywoola @laipz8200
|
||||
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
@ -143,6 +147,14 @@
|
||||
# Frontend
|
||||
/web/ @iamjoel
|
||||
|
||||
# Frontend - Platform and Features
|
||||
/web/config/ @lyzno1
|
||||
/web/contract/ @lyzno1
|
||||
/web/env.ts @lyzno1
|
||||
/web/features/ @lyzno1
|
||||
/web/hooks/ @lyzno1
|
||||
/web/scripts/gen-icons.mjs @lyzno1
|
||||
|
||||
# Frontend - Web Tests
|
||||
/.github/workflows/web-tests.yml @iamjoel
|
||||
|
||||
@ -253,7 +265,6 @@
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
/web/utils/format.ts @iamjoel @zxhlyh
|
||||
/web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
/web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
/web/app/components/billing/ @iamjoel @zxhlyh
|
||||
|
||||
14
.github/workflows/api-tests.yml
vendored
14
.github/workflows/api-tests.yml
vendored
@ -29,13 +29,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -91,13 +91,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -142,13 +142,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -195,7 +195,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
|
||||
24
.github/workflows/autofix.yml
vendored
24
.github/workflows/autofix.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
run: echo "autofix.ci updates pull request branches, not merge group refs."
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
@ -51,13 +51,22 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- name: Check dify-agent inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: dify-agent-changes
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
dify-agent/**/*.py
|
||||
dify-agent/pyproject.toml
|
||||
dify-agent/uv.lock
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@ -76,6 +85,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd dify-agent
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format .
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: count migration progress
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
14
.github/workflows/build-push.yml
vendored
14
.github/workflows/build-push.yml
vendored
@ -68,7 +68,7 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
@ -78,13 +78,13 @@ jobs:
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6.1.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
uses: depot/build-push-action@98e78adca7817480b8185f474a400b451d74e287 # v1.18.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
context: ${{ matrix.build_context }}
|
||||
@ -124,10 +124,10 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||
|
||||
- name: Validate Docker image
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7.2.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.build_context }}
|
||||
@ -156,14 +156,14 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
uses: docker/metadata-action@80c7e94dd9b9319bd5eb7a0e0fe9291e23a2a2e9 # v6.1.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
|
||||
182
.github/workflows/cli-release.yml
vendored
182
.github/workflows/cli-release.yml
vendored
@ -2,87 +2,165 @@ name: CLI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- 'difyctl-v*'
|
||||
inputs:
|
||||
release_tag:
|
||||
description: Dify release tag to attach difyctl assets to (blank = latest stable)
|
||||
required: false
|
||||
type: string
|
||||
workflow_call:
|
||||
inputs:
|
||||
release_tag:
|
||||
description: Dify release tag to attach difyctl assets to (blank = latest stable)
|
||||
required: false
|
||||
type: string
|
||||
release:
|
||||
types: [released]
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
group: difyctl-release
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: build standalone binaries (all targets)
|
||||
validate:
|
||||
name: validate manifest + resolve target Dify release
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: read
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
outputs:
|
||||
dify_tag: ${{ steps.resolve.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Export manifest to env
|
||||
run: node scripts/release-naming.mjs github-env >> "$GITHUB_ENV"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Resolve target Dify release
|
||||
id: resolve
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
EVENT_TAG: ${{ github.event.release.tag_name }}
|
||||
INPUT_TAG: ${{ inputs.release_tag }}
|
||||
run: |
|
||||
if [ -n "$EVENT_TAG" ]; then
|
||||
tag="$EVENT_TAG"
|
||||
elif [ -n "$INPUT_TAG" ]; then
|
||||
tag="$INPUT_TAG"
|
||||
else
|
||||
tag="$(gh api "repos/${GITHUB_REPOSITORY}/releases/latest" --jq .tag_name)"
|
||||
fi
|
||||
if [ -z "$tag" ]; then
|
||||
echo "::error::could not resolve a target Dify release tag"
|
||||
exit 1
|
||||
fi
|
||||
if ! gh release view "$tag" --repo "$GITHUB_REPOSITORY" >/dev/null 2>&1; then
|
||||
echo "::error::target Dify release ${tag} not found"
|
||||
exit 1
|
||||
fi
|
||||
echo "dify_tag=${tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "::notice::target Dify release ${tag}"
|
||||
|
||||
- name: Compatibility check
|
||||
env:
|
||||
DIFY_TAG: ${{ steps.resolve.outputs.dify_tag }}
|
||||
run: node scripts/release-naming.mjs compat-check "$DIFY_TAG"
|
||||
|
||||
- name: Reject duplicate difyctl version
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
if gh api "repos/${GITHUB_REPOSITORY}/git/ref/tags/${difyctlTag}" >/dev/null 2>&1; then
|
||||
echo "::error::difyctl ${version} already released (tag ${difyctlTag} exists); bump cli/package.json version"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
release:
|
||||
name: build + attach standalone binaries (all targets)
|
||||
needs: validate
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
env:
|
||||
DIFY_TAG: ${{ needs.validate.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Enable cross-arch native prebuilds
|
||||
working-directory: ./
|
||||
run: cat cli/scripts/cross-arch.pnpm.yaml >> pnpm-workspace.yaml
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Export manifest to env
|
||||
run: node scripts/release-naming.mjs github-env >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||
uses: oven-sh/setup-bun@0c5077e51419868618aeaa5fe8019c62421857d6 # v2.0.2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Install cross-arch native prebuilds
|
||||
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||
# so `bun build --compile` can embed the right .node into each target.
|
||||
working-directory: ./
|
||||
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||
bun-version-file: cli/.bun-version
|
||||
|
||||
- name: Compile standalone binaries (all targets)
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
run: |
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build:bin
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
run: scripts/release-write-checksums.sh
|
||||
|
||||
- name: Publish GitHub Release
|
||||
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||
with:
|
||||
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||
generate_release_notes: true
|
||||
fail_on_unmatched_files: true
|
||||
files: |
|
||||
cli/dist/bin/difyctl-v*
|
||||
- name: Attach difyctl assets to Dify release
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
gh release upload "$DIFY_TAG" dist/bin/${tagPrefix}* \
|
||||
--repo "$GITHUB_REPOSITORY" --clobber
|
||||
|
||||
- name: Prune stale difyctl assets
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
new_set="$(cd dist/bin && ls ${tagPrefix}*)"
|
||||
gh release view "$DIFY_TAG" --repo "$GITHUB_REPOSITORY" \
|
||||
--json assets --jq '.assets[].name' \
|
||||
| { grep -E "^${tagPrefix}" || true; } \
|
||||
| while IFS= read -r name; do
|
||||
if ! printf '%s\n' "$new_set" | grep -qxF -- "$name"; then
|
||||
echo "::notice::pruning stale asset ${name}"
|
||||
gh release delete-asset "$DIFY_TAG" "$name" \
|
||||
--repo "$GITHUB_REPOSITORY" --yes
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Create provenance tag
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
ref="refs/tags/${difyctlTag}"
|
||||
sha="$(git rev-parse HEAD)"
|
||||
status="$(gh api -X POST "repos/${GITHUB_REPOSITORY}/git/refs" \
|
||||
-f ref="$ref" -f sha="$sha" --silent --include 2>/dev/null \
|
||||
| awk 'NR==1 {print $2; exit}' || true)"
|
||||
case "$status" in
|
||||
201) echo "::notice::created ${ref}" ;;
|
||||
422) echo "::notice::tag ${ref} already exists; skipping (immutable)" ;;
|
||||
*) echo "::error::provenance tag ${ref} not created (HTTP ${status:-unknown})"; exit 1 ;;
|
||||
esac
|
||||
|
||||
2
.github/workflows/cli-smoke.yml
vendored
2
.github/workflows/cli-smoke.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
8
.github/workflows/cli-tests.yml
vendored
8
.github/workflows/cli-tests.yml
vendored
@ -30,19 +30,23 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Validate release manifest
|
||||
if: matrix.os == 'depot-ubuntu-24.04'
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
|
||||
12
.github/workflows/db-migration-test.yml
vendored
12
.github/workflows/db-migration-test.yml
vendored
@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -40,7 +40,7 @@ jobs:
|
||||
cp envs/middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
uses: hoverkraft-tech/compose-action@11beaa1c2dae4e8ed7b1665aa074723b6cecb0e4 # v3.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
uses: hoverkraft-tech/compose-action@11beaa1c2dae4e8ed7b1665aa074723b6cecb0e4 # v3.0.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
4
.github/workflows/deploy-saas.yml
vendored
4
.github/workflows/deploy-saas.yml
vendored
@ -21,8 +21,8 @@ jobs:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}
|
||||
|
||||
6
.github/workflows/docker-build.yml
vendored
6
.github/workflows/docker-build.yml
vendored
@ -53,7 +53,7 @@ jobs:
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
uses: depot/build-push-action@98e78adca7817480b8185f474a400b451d74e287 # v1.18.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
@ -77,10 +77,10 @@ jobs:
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 # v4.1.0
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
uses: docker/build-push-action@f9f3042f7e2789586610d6e8b85c8f03e5195baf # v7.2.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
|
||||
2
.github/workflows/hotfix-cherry-pick.yml
vendored
2
.github/workflows/hotfix-cherry-pick.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
name: Require cherry-pick provenance
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/main-ci.yml
vendored
2
.github/workflows/main-ci.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: changes
|
||||
with:
|
||||
|
||||
4
.github/workflows/pyrefly-diff.yml
vendored
4
.github/workflows/pyrefly-diff.yml
vendored
@ -17,12 +17,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
@ -21,10 +21,10 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Checkout default branch (trusted code)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -17,12 +17,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
- uses: actions/stale@eb5cf3af3ac0a1aa4c9c45633dd1ae542a27a899 # v10.3.0
|
||||
with:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
|
||||
71
.github/workflows/style.yml
vendored
71
.github/workflows/style.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -95,6 +95,51 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip
|
||||
|
||||
ts-common-style:
|
||||
name: TS Common
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
checks: write
|
||||
pull-requests: read
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
cli/**
|
||||
e2e/**
|
||||
sdks/nodejs-client/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
eslint.config.mjs
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
@ -105,28 +150,14 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
- name: Style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
- name: Type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
@ -140,7 +171,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/tool-test-sdks.yaml
vendored
2
.github/workflows/tool-test-sdks.yaml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
4
.github/workflows/translate-i18n-claude.yml
vendored
4
.github/workflows/translate-i18n-claude.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@1dc994ee7a008f0ecc866d9ac23ef036b7229f84 # v1.0.127
|
||||
uses: anthropics/claude-code-action@fbda2eb1bdc90d319b8d853f5deb53bca199a7c1 # v1.0.140
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
4
.github/workflows/vdb-tests-full.yml
vendored
4
.github/workflows/vdb-tests-full.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -36,7 +36,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
4
.github/workflows/vdb-tests.yml
vendored
4
.github/workflows/vdb-tests.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
4
.github/workflows/web-e2e.yml
vendored
4
.github/workflows/web-e2e.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -28,7 +28,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
10
.github/workflows/web-tests.yml
vendored
10
.github/workflows/web-tests.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
@ -102,7 +102,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -117,7 +117,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f # v7.0.0
|
||||
with:
|
||||
directory: packages/dify-ui/coverage
|
||||
flags: dify-ui
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -259,3 +259,6 @@ scripts/stress-test/reports/
|
||||
.qoder/*
|
||||
.context/
|
||||
.eslintcache
|
||||
|
||||
# Vitest local reports
|
||||
web/.vitest-reports/
|
||||
|
||||
27
SECURITY.md
Normal file
27
SECURITY.md
Normal file
@ -0,0 +1,27 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
|
||||
|
||||
https://github.com/langgenius/dify/security/advisories/new
|
||||
|
||||
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
|
||||
|
||||
When submitting a report, include as much relevant information as you can safely provide, such as:
|
||||
|
||||
- A description of the vulnerability
|
||||
- Steps to reproduce, if safe to share privately
|
||||
- Affected components, versions, or configurations
|
||||
- Potential impact
|
||||
- Any suggested mitigation or fix, if available
|
||||
|
||||
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
|
||||
|
||||
## Public Disclosure
|
||||
|
||||
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
|
||||
|
||||
## Security Updates
|
||||
|
||||
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.
|
||||
@ -17,7 +17,7 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
git g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -34,6 +34,7 @@ from clients.agent_backend.request_builder import (
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
@ -49,6 +50,7 @@ __all__ = [
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendAgentAppRunInput",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
|
||||
@ -30,6 +30,7 @@ from dify_agent.layers.execution_context import (
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
@ -45,8 +46,10 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
DIFY_SHELL_LAYER_ID = "shell"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
@ -166,6 +169,10 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
@ -181,9 +188,154 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendAgentAppRunInput(BaseModel):
|
||||
"""Inputs to build one Agent App conversation-turn run request.
|
||||
|
||||
Unlike the workflow-node input there is no workflow-node-job prompt and no
|
||||
previous-node context: the user prompt is the chat message, and multi-turn
|
||||
continuity comes from ``session_snapshot`` + the history layer keyed by the
|
||||
conversation.
|
||||
"""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "agent_app"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
|
||||
"""Build an Agent App conversation-turn run request.
|
||||
|
||||
Layer graph: optional Agent Soul system prompt → user prompt →
|
||||
execution context → optional history (multi-turn) → LLM → optional
|
||||
plugin tools → optional structured output. Mirrors the workflow-node
|
||||
layer ordering minus the workflow-job / previous-node prompt.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=AGENT_APP_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
@ -302,6 +454,18 @@ class AgentBackendRunRequestBuilder:
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
|
||||
135
api/clients/agent_backend/workspace_files_client.py
Normal file
135
api/clients/agent_backend/workspace_files_client.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""API-side client for the agent backend's read-only workspace file endpoints.
|
||||
|
||||
The agent backend exposes ``/workspaces/{session_id}/files{,/preview,/download}``
|
||||
to inspect a shell-layer sandbox workspace. This thin synchronous client proxies
|
||||
those reads for the console FS inspector and normalizes transport/HTTP failures
|
||||
into the API backend's ``AgentBackendError`` boundary, preserving the backend's
|
||||
status code and ``{code, message}`` detail so the controller can relay them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
|
||||
_DEFAULT_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class WorkspaceFileEntry(BaseModel):
|
||||
"""One entry in a workspace directory listing."""
|
||||
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResult(BaseModel):
|
||||
"""Directory listing of a workspace path."""
|
||||
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntry]
|
||||
truncated: bool
|
||||
|
||||
|
||||
class WorkspacePreviewResult(BaseModel):
|
||||
"""Inline preview of a workspace file."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkspaceDownloadResult:
|
||||
"""Decoded bytes of a workspace file for download."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
content: bytes
|
||||
|
||||
|
||||
class WorkspaceFilesBackendClient:
|
||||
"""Synchronous proxy to the agent backend workspace file endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
|
||||
transport: httpx.BaseTransport | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._transport = transport
|
||||
|
||||
def list_files(self, session_id: str, path: str) -> WorkspaceListResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files", params={"path": path})
|
||||
return WorkspaceListResult.model_validate(data)
|
||||
|
||||
def preview(self, session_id: str, path: str) -> WorkspacePreviewResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/preview", params={"path": path})
|
||||
return WorkspacePreviewResult.model_validate(data)
|
||||
|
||||
def download(self, session_id: str, path: str) -> WorkspaceDownloadResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/download", params={"path": path})
|
||||
encoded = data.get("content_base64")
|
||||
if not isinstance(encoded, str):
|
||||
raise AgentBackendHTTPError("agent backend download response missing content", status_code=502, detail=data)
|
||||
try:
|
||||
content = base64.b64decode(encoded, validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend returned undecodable download content", status_code=502, detail=str(exc)
|
||||
) from exc
|
||||
size = data.get("size")
|
||||
return WorkspaceDownloadResult(
|
||||
path=str(data.get("path", path)),
|
||||
size=int(size) if isinstance(size, (int, float)) else len(content),
|
||||
truncated=bool(data.get("truncated")),
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _get(self, route: str, *, params: dict[str, str]) -> dict[str, object]:
|
||||
url = f"{self._base_url}{route}"
|
||||
try:
|
||||
with httpx.Client(timeout=self._timeout, transport=self._transport, trust_env=False) as client:
|
||||
response = client.get(url, params=params)
|
||||
except httpx.HTTPError as exc:
|
||||
raise AgentBackendTransportError(f"failed to reach agent backend workspace endpoint: {exc}") from exc
|
||||
if response.status_code >= 400:
|
||||
detail: object
|
||||
try:
|
||||
detail = response.json().get("detail", response.text)
|
||||
except ValueError:
|
||||
detail = response.text
|
||||
raise AgentBackendHTTPError(
|
||||
f"agent backend workspace request failed ({response.status_code})",
|
||||
status_code=response.status_code,
|
||||
detail=detail,
|
||||
)
|
||||
body = response.json()
|
||||
if not isinstance(body, dict):
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend workspace response was not an object", status_code=502, detail=body
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WorkspaceDownloadResult",
|
||||
"WorkspaceFileEntry",
|
||||
"WorkspaceFilesBackendClient",
|
||||
"WorkspaceListResult",
|
||||
"WorkspacePreviewResult",
|
||||
]
|
||||
@ -4,6 +4,12 @@ CLI command modules extracted from `commands.py`.
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .data_migration import (
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
)
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -26,7 +32,12 @@ from .retention import (
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .system import (
|
||||
convert_to_agent_apps,
|
||||
fix_app_site_missing,
|
||||
reset_encrypt_key_pair,
|
||||
upgrade_db,
|
||||
)
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
@ -48,10 +59,13 @@ __all__ = [
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"export_migration_data",
|
||||
"export_migration_data_template",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"import_migration_data",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
@ -59,6 +73,7 @@ __all__ = [
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"migration_data_wizard",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
|
||||
754
api/commands/data_migration.py
Normal file
754
api/commands/data_migration.py
Normal file
@ -0,0 +1,754 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.model import App
|
||||
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
|
||||
from services.data_migration.entities import (
|
||||
DependencyKind,
|
||||
ImportOptions,
|
||||
MigrationDataError,
|
||||
ReportContext,
|
||||
ResourceReportItem,
|
||||
)
|
||||
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
|
||||
from services.data_migration.import_service import ImportRequest, MigrationImportService
|
||||
from services.data_migration.package_service import MigrationPackageService
|
||||
from services.data_migration.report_service import MigrationReportService
|
||||
|
||||
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
|
||||
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
|
||||
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
|
||||
WizardToolMap = dict[str, dict[str, str | None]]
|
||||
WizardToolSelection = dict[str, list[str]]
|
||||
|
||||
|
||||
def _scripted_export_template() -> dict[str, Any]:
|
||||
return {
|
||||
"source_tenant": {
|
||||
"mode": "single",
|
||||
"id": "",
|
||||
"name": "admin's Workspace",
|
||||
},
|
||||
"apps": {
|
||||
"modes": ["workflow", "advanced-chat"],
|
||||
"ids": [],
|
||||
"all": True,
|
||||
},
|
||||
"include_referenced_tools": True,
|
||||
"additional_tools": {
|
||||
"api_tools": [],
|
||||
"workflow_tools": [],
|
||||
"mcp_tools": [],
|
||||
},
|
||||
"include_secrets": False,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": False,
|
||||
"id_strategy": "preserve-id",
|
||||
"conflict_strategy": "fail",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to write the export config JSON template. Prints to stdout when omitted.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
|
||||
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
|
||||
if output_file is None:
|
||||
click.echo(template_json, nl=False)
|
||||
return
|
||||
path = Path(output_file)
|
||||
if path.exists() and not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
path.write_text(template_json)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to export config JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file), ("--output", output_file))
|
||||
assert input_file is not None
|
||||
assert output_file is not None
|
||||
raw_config = _load_json_object(input_file, "Export config")
|
||||
selection = ExportConfigParser().parse(raw_config)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
@click.command("import-app-migration", help="Import a versioned migration data package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
|
||||
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
|
||||
@click.option(
|
||||
"--id-strategy",
|
||||
default=None,
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
help="Override package ID strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--conflict-strategy",
|
||||
default=None,
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
help="Override package conflict strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
|
||||
default=None,
|
||||
help="Override package app API token creation behavior.",
|
||||
)
|
||||
def import_migration_data(
|
||||
input_file: str | None,
|
||||
target_tenant: str | None,
|
||||
operator_email: str | None,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file))
|
||||
assert input_file is not None
|
||||
package = MigrationPackageService().load_package(input_file)
|
||||
result = MigrationImportService().import_package(
|
||||
ImportRequest(
|
||||
package=package,
|
||||
cli_target_tenant=target_tenant,
|
||||
operator_email=operator_email,
|
||||
options_override=_build_options_override(
|
||||
package.metadata.import_options,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
create_app_api_token_on_import=create_app_api_token_on_import,
|
||||
),
|
||||
)
|
||||
)
|
||||
_render_report(result.report_items, context=result.report_context)
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
|
||||
normalized = raw.strip().lower()
|
||||
if normalized == "all":
|
||||
return values
|
||||
|
||||
selected: list[str] = []
|
||||
for part in raw.split(","):
|
||||
stripped = part.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
index = int(stripped)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
|
||||
if index < 1 or index > len(values):
|
||||
raise click.ClickException(f"Selection index out of range: {index}")
|
||||
selected.append(values[index - 1])
|
||||
return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _print_wizard_step(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"==== {title} ====")
|
||||
|
||||
|
||||
def _print_wizard_substep(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"-- {title} --")
|
||||
|
||||
|
||||
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
|
||||
def migration_data_wizard() -> None:
|
||||
try:
|
||||
tenant = _prompt_source_tenant()
|
||||
apps = _eligible_apps_for_tenant(tenant.id)
|
||||
app_ids = _prompt_app_ids(apps)
|
||||
_print_wizard_step("Referenced Tools")
|
||||
include_referenced_tools = click.confirm(
|
||||
"Automatically export tools referenced by selected apps? [y/n, default: y]",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
|
||||
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
|
||||
_print_auto_tools(auto_tools)
|
||||
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
|
||||
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
|
||||
_print_wizard_step("Output")
|
||||
output_file, overwrite = _prompt_output_file()
|
||||
|
||||
selection = ExportConfigParser().parse(
|
||||
{
|
||||
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
|
||||
"apps": {"ids": app_ids, "all": False},
|
||||
"include_referenced_tools": include_referenced_tools,
|
||||
"additional_tools": additional_tools,
|
||||
"include_secrets": include_secrets,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": create_tokens,
|
||||
"id_strategy": id_strategy,
|
||||
"conflict_strategy": conflict_strategy,
|
||||
},
|
||||
}
|
||||
)
|
||||
_confirm_wizard_summary(
|
||||
tenant_name=tenant.name,
|
||||
app_names=[app.name for app in apps if app.id in set(app_ids)],
|
||||
auto_tools=auto_tools,
|
||||
additional_tools=additional_tools,
|
||||
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
|
||||
include_referenced_tools=include_referenced_tools,
|
||||
include_secrets=include_secrets,
|
||||
create_tokens=create_tokens,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
output_file=output_file,
|
||||
)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_print_wizard_step("Report")
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def _load_json_object(path: str, label: str) -> dict[str, Any]:
|
||||
try:
|
||||
with Path(path).open(encoding="utf-8") as file:
|
||||
raw = json.load(file)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise MigrationDataError(f"{label} JSON must be an object.")
|
||||
return raw
|
||||
|
||||
|
||||
def _require_options(*options: tuple[str, object | None]) -> None:
|
||||
missing_options = [name for name, value in options if value is None]
|
||||
if missing_options:
|
||||
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
|
||||
|
||||
|
||||
def _build_options_override(
|
||||
package_options: ImportOptions,
|
||||
*,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> ImportOptions | None:
|
||||
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
|
||||
return None
|
||||
return ImportOptions.from_mapping(
|
||||
{
|
||||
"id_strategy": id_strategy or package_options.id_strategy,
|
||||
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
|
||||
"create_app_api_token_on_import": (
|
||||
create_app_api_token_on_import
|
||||
if create_app_api_token_on_import is not None
|
||||
else package_options.create_app_api_token_on_import
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _prompt_source_tenant() -> Tenant:
|
||||
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
|
||||
if not tenants:
|
||||
raise MigrationDataError("No tenants found.")
|
||||
|
||||
_print_wizard_step("Source Tenant")
|
||||
click.echo("Source tenants:")
|
||||
for index, tenant in enumerate(tenants, 1):
|
||||
click.echo(f"{index}. {tenant.name} ({tenant.id})")
|
||||
|
||||
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
|
||||
if tenant_index < 1 or tenant_index > len(tenants):
|
||||
raise click.ClickException(f"Selection index out of range: {tenant_index}")
|
||||
return tenants[tenant_index - 1]
|
||||
|
||||
|
||||
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
|
||||
return list(
|
||||
db.session.scalars(
|
||||
sa.select(App)
|
||||
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
|
||||
.order_by(App.name.asc())
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def _prompt_app_ids(apps: list[App]) -> list[str]:
|
||||
if not apps:
|
||||
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
|
||||
|
||||
_print_wizard_step("App Selection")
|
||||
click.echo("Currently supported app types: workflow and chatflow.")
|
||||
click.echo("Workflow/chatflow apps:")
|
||||
for index, app in enumerate(apps, 1):
|
||||
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
|
||||
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
|
||||
app_ids = parse_index_selection(
|
||||
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
|
||||
[app.id for app in apps],
|
||||
)
|
||||
selected_apps = [app for app in apps if app.id in set(app_ids)]
|
||||
click.echo("Selected apps:")
|
||||
for app in selected_apps:
|
||||
click.echo(f"- {app.name} ({app.id})")
|
||||
return app_ids
|
||||
|
||||
|
||||
def _prompt_import_options() -> tuple[bool, bool, str, str]:
|
||||
_print_wizard_step("Import Options")
|
||||
_print_wizard_substep("Secrets")
|
||||
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
|
||||
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
|
||||
click.echo("If you choose no, credentials are omitted or masked,")
|
||||
click.echo("and MCP providers are exported as dependency metadata only.")
|
||||
click.echo("Treat the output JSON as sensitive if you choose yes.")
|
||||
include_secrets = click.confirm(
|
||||
"Include secrets in output JSON? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("App API Tokens")
|
||||
click.echo("When enabled, import will create an app API token if the imported app has none,")
|
||||
click.echo("or reuse an existing app API token if one already exists.")
|
||||
create_tokens = click.confirm(
|
||||
"Create or reuse app API tokens during import? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("ID Strategy")
|
||||
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
|
||||
click.echo("or use target-generated IDs.")
|
||||
click.echo("preserve-id: keep source IDs where the target service supports it.")
|
||||
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
|
||||
id_strategy = click.prompt(
|
||||
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
default="preserve-id",
|
||||
show_default=True,
|
||||
)
|
||||
_print_wizard_substep("Conflict Strategy")
|
||||
click.echo("Conflict strategy controls what import does when a target resource already exists.")
|
||||
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
|
||||
click.echo("skip: keep the existing target resource and skip importing that resource.")
|
||||
click.echo("update: update the existing target resource in place.")
|
||||
conflict_strategy = click.prompt(
|
||||
"Import conflict strategy. Enter one of: fail, skip, update",
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
default="update",
|
||||
show_default=True,
|
||||
)
|
||||
return include_secrets, create_tokens, id_strategy, conflict_strategy
|
||||
|
||||
|
||||
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
|
||||
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
|
||||
if not include_referenced_tools:
|
||||
return auto_tools
|
||||
discovery_service = DependencyDiscoveryService()
|
||||
for app in apps:
|
||||
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
|
||||
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
|
||||
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
|
||||
for dependency in discovery_service.discover_from_dsl(dsl):
|
||||
if dependency.kind == DependencyKind.API_TOOL:
|
||||
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
|
||||
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
|
||||
dependency.provider_id
|
||||
)
|
||||
elif dependency.kind == DependencyKind.MCP_TOOL:
|
||||
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
return auto_tools
|
||||
|
||||
|
||||
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
|
||||
return {
|
||||
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
|
||||
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
|
||||
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [ApiToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(ApiToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [WorkflowToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(WorkflowToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [MCPToolProvider.name == name]
|
||||
if identifier:
|
||||
predicates.append(MCPToolProvider.server_identifier == identifier)
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(MCPToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_uuid_string(value: str | None) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
|
||||
_print_wizard_step("Automatically Discovered Tools")
|
||||
click.echo("Automatically discovered tools:")
|
||||
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
|
||||
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
|
||||
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
|
||||
|
||||
|
||||
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
|
||||
click.echo(label)
|
||||
if not values:
|
||||
click.echo("- none")
|
||||
return
|
||||
for name, identifier in sorted(values.items()):
|
||||
click.echo(f"- {_format_tool_name_id(name, identifier)}")
|
||||
|
||||
|
||||
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
|
||||
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
|
||||
_print_wizard_step("Additional Tools")
|
||||
if not click.confirm(
|
||||
"Export additional tools manually? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
):
|
||||
_print_final_tool_selection(auto_tools, selections, {})
|
||||
return selections
|
||||
manual_labels: dict[str, str] = {}
|
||||
api_tool_options = [
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["api_tools"] = _prompt_tool_category(
|
||||
"Custom API tools",
|
||||
api_tool_options,
|
||||
auto_tools=auto_tools["api_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
|
||||
workflow_tool_options = [
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["workflow_tools"] = _prompt_tool_category(
|
||||
"Workflow tools",
|
||||
workflow_tool_options,
|
||||
auto_tools=auto_tools["workflow_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
|
||||
mcp_tool_options = [
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["mcp_tools"] = _prompt_tool_category(
|
||||
"MCP tools",
|
||||
mcp_tool_options,
|
||||
auto_tools=auto_tools["mcp_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
|
||||
_print_final_tool_selection(auto_tools, selections, manual_labels)
|
||||
return selections
|
||||
|
||||
|
||||
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
|
||||
labels: dict[str, str] = {}
|
||||
if selected_tools["api_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id)
|
||||
.order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["api_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["workflow_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["workflow_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["mcp_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["mcp_tools"],
|
||||
)
|
||||
)
|
||||
return labels
|
||||
|
||||
|
||||
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
|
||||
selected = set(selected_values)
|
||||
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
|
||||
|
||||
|
||||
def _prompt_tool_category(
|
||||
label: str,
|
||||
options: list[tuple[str, str, str]],
|
||||
*,
|
||||
auto_tools: dict[str, str | None],
|
||||
) -> list[str]:
|
||||
if not options:
|
||||
click.echo(f"{label}: none")
|
||||
return []
|
||||
_print_wizard_step(label)
|
||||
for index, (value, name, detail) in enumerate(options, 1):
|
||||
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
|
||||
click.echo(f"{index}. {marker} {name} ({detail})")
|
||||
raw = click.prompt(
|
||||
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
|
||||
default="",
|
||||
show_default=cast(Any, "empty"),
|
||||
)
|
||||
if not raw.strip():
|
||||
return []
|
||||
return parse_index_selection(raw, [value for value, _, _ in options])
|
||||
|
||||
|
||||
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
|
||||
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
|
||||
|
||||
|
||||
def _print_final_tool_selection(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
_print_wizard_step("Final Tool Selection")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
|
||||
|
||||
def _print_tool_selection_body(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo("Final tools to export:")
|
||||
_print_final_tool_category(
|
||||
"Custom API tools",
|
||||
auto_tools["api_tools"],
|
||||
additional_tools["api_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category(
|
||||
"Workflow tools",
|
||||
auto_tools["workflow_tools"],
|
||||
additional_tools["workflow_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
|
||||
|
||||
|
||||
def _print_final_tool_category(
|
||||
label: str,
|
||||
auto_tools: dict[str, str | None],
|
||||
manual_values: list[str],
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo(label)
|
||||
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
|
||||
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
|
||||
lines.extend(
|
||||
f"- [manual] {manual_labels.get(value, value)}"
|
||||
for value in manual_values
|
||||
if value not in auto_tools and value not in auto_identifiers
|
||||
)
|
||||
if not lines:
|
||||
click.echo("- none")
|
||||
return
|
||||
for line in lines:
|
||||
click.echo(line)
|
||||
|
||||
|
||||
def _format_tool_name_id(name: str, identifier: str | None) -> str:
|
||||
if identifier and identifier != name:
|
||||
return f"{name}: {identifier}"
|
||||
return name
|
||||
|
||||
|
||||
def _confirm_wizard_summary(
|
||||
*,
|
||||
tenant_name: str,
|
||||
app_names: list[str],
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
include_referenced_tools: bool,
|
||||
include_secrets: bool,
|
||||
create_tokens: bool,
|
||||
id_strategy: str,
|
||||
conflict_strategy: str,
|
||||
output_file: str,
|
||||
) -> None:
|
||||
_print_wizard_step("Summary")
|
||||
click.echo("Migration export summary:")
|
||||
click.echo(f"source tenant: {tenant_name}")
|
||||
click.echo(f"selected apps: {len(app_names)}")
|
||||
for app_name in app_names:
|
||||
click.echo(f"- {app_name}")
|
||||
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
click.echo(f"include secrets: {str(include_secrets).lower()}")
|
||||
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
|
||||
click.echo(f"id strategy: {id_strategy}")
|
||||
click.echo(f"conflict strategy: {conflict_strategy}")
|
||||
click.echo(f"output path: {output_file}")
|
||||
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _prompt_output_file() -> tuple[str, bool]:
|
||||
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
output_file = click.prompt("Output path", default=default_output, show_default=True)
|
||||
if output_file.lower() in {"y", "yes", "n", "no"}:
|
||||
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
|
||||
overwrite = False
|
||||
if Path(output_file).exists():
|
||||
overwrite = click.confirm(
|
||||
"Output file exists. Overwrite? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
return output_file, overwrite
|
||||
|
||||
|
||||
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
|
||||
if context is None:
|
||||
return ReportContext(output_path=output_path)
|
||||
return ReportContext(
|
||||
output_path=output_path,
|
||||
source_scope=context.source_scope,
|
||||
selected_app_count=context.selected_app_count,
|
||||
include_secrets=context.include_secrets,
|
||||
target_tenant=context.target_tenant,
|
||||
operator_email=context.operator_email,
|
||||
app_api_tokens_created=context.app_api_tokens_created,
|
||||
app_api_tokens_reused=context.app_api_tokens_reused,
|
||||
id_mapping_count=context.id_mapping_count,
|
||||
id_mappings=context.id_mappings,
|
||||
)
|
||||
|
||||
|
||||
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
|
||||
for line in MigrationReportService().render(report_items, context=context):
|
||||
click.echo(line)
|
||||
@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
Migrate annotation data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
Migrate vector database data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
|
||||
@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
current_state = self.current_state
|
||||
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
|
||||
|
||||
@ -21,3 +21,13 @@ class AgentBackendConfig(BaseSettings):
|
||||
description="Scenario used by the fake Agent backend client.",
|
||||
default="success",
|
||||
)
|
||||
|
||||
AGENT_SHELL_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Inject the dify.shell layer (sandboxed bash workspace) into Agent runs. "
|
||||
"Requires the agent backend to be wired with a shellctl entrypoint; keep it "
|
||||
"off until shellctl is deployed, otherwise every agent run that includes the "
|
||||
"shell layer will fail."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
@ -949,6 +949,11 @@ class AuthConfig(BaseSettings):
|
||||
default=60,
|
||||
)
|
||||
|
||||
DEVICE_FLOW_APPROVE_RATE_LIMIT_PER_HOUR: PositiveInt = Field(
|
||||
description="Max device-flow approve requests per session per hour on /openapi/oauth/device/approve.",
|
||||
default=10,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
|
||||
@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new Agent App type). The runtime model / prompt / tools
|
||||
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
|
||||
# template; create_app still creates a model-less app_model_config row to hold
|
||||
# app-level presentation features (opener, follow-up, citations, ...).
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,10 +1,40 @@
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
action: str
|
||||
|
||||
|
||||
|
||||
@ -6,10 +6,11 @@ These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
from typing import Any, Literal, NotRequired, Protocol, TypedDict
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
@ -35,6 +36,14 @@ QueryParamDoc = TypedDict(
|
||||
},
|
||||
)
|
||||
|
||||
JsonResponseWithStatus = tuple[dict[str, Any], int]
|
||||
|
||||
|
||||
class QueryArgs(Protocol):
|
||||
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
|
||||
|
||||
def getlist(self, key: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
@ -167,6 +176,58 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
return params
|
||||
|
||||
|
||||
def query_params_from_request[ModelT: BaseModel](
|
||||
model: type[ModelT],
|
||||
*,
|
||||
list_fields: Iterable[str] = (),
|
||||
args: QueryArgs | None = None,
|
||||
use_defaults_for_malformed_ints: bool = False,
|
||||
) -> ModelT:
|
||||
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
|
||||
|
||||
Repeated params need explicit ``getlist()`` handling because Werkzeug's
|
||||
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
|
||||
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
|
||||
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
|
||||
defaults for malformed optional integer params.
|
||||
"""
|
||||
|
||||
query_args = args or request.args
|
||||
params: dict[str, Any] = query_args.to_dict()
|
||||
for field_name in list_fields:
|
||||
params[field_name] = query_args.getlist(field_name)
|
||||
|
||||
if use_defaults_for_malformed_ints:
|
||||
_drop_malformed_defaulted_integer_params(model, params)
|
||||
return model.model_validate(params)
|
||||
|
||||
|
||||
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
|
||||
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return
|
||||
|
||||
for name, value in list(params.items()):
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
|
||||
field = model.model_fields.get(name)
|
||||
if field is None or field.is_required():
|
||||
continue
|
||||
|
||||
property_schema = properties.get(name)
|
||||
if not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
if _nullable_property_schema(property_schema).get("type") != "integer":
|
||||
continue
|
||||
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
params.pop(name)
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
@ -239,6 +300,7 @@ __all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"query_params_from_request",
|
||||
"register_enum_models",
|
||||
"register_response_schema_model",
|
||||
"register_response_schema_models",
|
||||
|
||||
@ -51,6 +51,9 @@ from .agent import roster as agent_roster
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
agent_app_access,
|
||||
agent_app_feature,
|
||||
agent_app_workspace,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
@ -119,6 +122,7 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import tag controllers
|
||||
@ -134,6 +138,7 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -146,6 +151,9 @@ __all__ = [
|
||||
"activate",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_app_access",
|
||||
"agent_app_feature",
|
||||
"agent_app_workspace",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
@ -206,6 +214,9 @@ __all__ = [
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
|
||||
@ -1,53 +1,91 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from fields.agent_fields import (
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
|
||||
class WorkflowAgentComposerApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -55,84 +93,116 @@ class WorkflowAgentComposerValidateApi(Resource):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
|
||||
class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
|
||||
class WorkflowAgentComposerImpactApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
|
||||
class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
|
||||
class AgentAppComposerApi(Resource):
|
||||
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model: App):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
|
||||
class AgentAppComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -140,14 +210,20 @@ class AgentAppComposerValidateApi(Resource):
|
||||
def post(self, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
|
||||
class AgentAppComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
@ -4,11 +4,25 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
@ -29,6 +43,14 @@ register_schema_models(
|
||||
RosterAgentUpdatePayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
|
||||
|
||||
def _agent_roster_service() -> AgentRosterService:
|
||||
@ -37,96 +59,130 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
|
||||
@console_ns.route("/agents")
|
||||
class AgentRosterListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RosterListQuery))
|
||||
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
return dump_response(
|
||||
AgentRosterListResponse,
|
||||
_agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str):
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
|
||||
), 201
|
||||
|
||||
|
||||
@console_ns.route("/agents/invite-options")
|
||||
class AgentInviteOptionsApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
|
||||
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
return dump_response(
|
||||
AgentInviteOptionsResponse,
|
||||
_agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>")
|
||||
class AgentRosterDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.response(204, "Agent archived")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions")
|
||||
class AgentRosterVersionsApi(Resource):
|
||||
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotListResponse,
|
||||
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
|
||||
class AgentRosterVersionDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
_agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
),
|
||||
)
|
||||
|
||||
59
api/controllers/console/app/agent_app_access.py
Normal file
59
api/controllers/console/app/agent_app_access.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Agent App access & sharing endpoints (read-only workflow references).
|
||||
|
||||
An Agent App is backed by a roster Agent that workflow Agent nodes may also
|
||||
reference. This exposes the read-only "Workflow access" surface from the PRD:
|
||||
which workflow apps use this Agent, without leaking the workflows' internals.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
|
||||
|
||||
class AgentReferencingWorkflowResponse(ResponseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
app_mode: str
|
||||
workflow_id: str
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentReferencingWorkflowsResponse(ResponseModel):
|
||||
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
|
||||
class AgentAppReferencingWorkflowsResource(Resource):
|
||||
@console_ns.doc("list_agent_app_referencing_workflows")
|
||||
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Referencing workflows listed successfully",
|
||||
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
|
||||
tenant_id=tenant_id, app_id=app_model.id
|
||||
)
|
||||
return AgentReferencingWorkflowsResponse(
|
||||
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
|
||||
).model_dump(mode="json")
|
||||
93
api/controllers/console/app/agent_app_feature.py
Normal file
93
api/controllers/console/app/agent_app_feature.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Agent App presentation-feature configuration endpoint.
|
||||
|
||||
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
|
||||
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
|
||||
config) is the wrong place to configure its app-level presentation features.
|
||||
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
|
||||
opener, follow-up suggestions, citations, content moderation and speech — and
|
||||
persists them onto the app's ``app_model_config`` without touching anything the
|
||||
Soul owns.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.agent_config_entities import (
|
||||
AgentFeatureToggleConfig,
|
||||
AgentSensitiveWordAvoidanceFeatureConfig,
|
||||
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
|
||||
AgentTextToSpeechFeatureConfig,
|
||||
)
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_feature_service import AgentAppFeatureConfigService
|
||||
|
||||
|
||||
class AgentAppFeaturesPayload(BaseModel):
|
||||
"""Presentation features configurable on an Agent App.
|
||||
|
||||
All fields are optional; an omitted field is reset to its disabled/empty
|
||||
default (the config form sends the full desired feature state on save).
|
||||
"""
|
||||
|
||||
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
|
||||
suggested_questions: list[str] | None = Field(
|
||||
default=None, description="Preset questions shown alongside the opener"
|
||||
)
|
||||
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
|
||||
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
|
||||
)
|
||||
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
|
||||
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
|
||||
retriever_resource: AgentFeatureToggleConfig | None = Field(
|
||||
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
|
||||
)
|
||||
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
|
||||
default=None, description="Content moderation config"
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(console_ns, AgentAppFeaturesPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-features")
|
||||
class AgentAppFeatureConfigResource(Resource):
|
||||
@console_ns.doc("update_agent_app_features")
|
||||
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
|
||||
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"})
|
||||
319
api/controllers/console/app/agent_app_workspace.py
Normal file
319
api/controllers/console/app/agent_app_workspace.py
Normal file
@ -0,0 +1,319 @@
|
||||
"""Agent App sandbox file-system inspector (read-only).
|
||||
|
||||
Exposes the PRD "rc1-like sandbox file system, downloadable not editable" view
|
||||
for an Agent App conversation: list a directory, preview a file, or download a
|
||||
file from the conversation's shell-layer workspace. The API never touches
|
||||
shellctl directly — it resolves the conversation's sandbox ``session_id`` from
|
||||
the stored session snapshot and proxies to the agent backend's read-only
|
||||
workspace endpoints.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
from clients.agent_backend.workspace_files_client import WorkspaceDownloadResult
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_workspace_service import (
|
||||
AgentAppWorkspaceService,
|
||||
AgentWorkspaceInspectorError,
|
||||
WorkflowAgentWorkspaceService,
|
||||
)
|
||||
|
||||
|
||||
class _WorkspaceFileDownloadField(fields.Raw):
|
||||
__schema_type__ = "string"
|
||||
__schema_format__ = "binary"
|
||||
|
||||
|
||||
class AgentWorkspaceListQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class AgentWorkspaceFileQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceListQuery(BaseModel):
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceFileQuery(BaseModel):
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileEntryResponse(ResponseModel):
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResponse(ResponseModel):
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntryResponse] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class WorkspacePreviewResponse(ResponseModel):
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, WorkspaceListResponse)
|
||||
register_response_schema_models(console_ns, WorkspacePreviewResponse)
|
||||
|
||||
|
||||
def _handle(exc: Exception) -> tuple[dict[str, object], int]:
|
||||
if isinstance(exc, AgentWorkspaceInspectorError):
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
if isinstance(exc, AgentBackendHTTPError):
|
||||
detail = exc.detail
|
||||
if isinstance(detail, dict):
|
||||
return {
|
||||
"code": detail.get("code", "agent_backend_error"),
|
||||
"message": detail.get("message", str(exc)),
|
||||
}, exc.status_code
|
||||
return {"code": "agent_backend_error", "message": str(detail)}, exc.status_code
|
||||
if isinstance(exc, AgentBackendTransportError):
|
||||
return {"code": "agent_backend_unreachable", "message": str(exc)}, 502
|
||||
raise exc
|
||||
|
||||
|
||||
def _download_response(result: WorkspaceDownloadResult) -> Response | tuple[dict[str, object], int]:
|
||||
if result.truncated:
|
||||
return {
|
||||
"code": "workspace_file_too_large",
|
||||
"message": (
|
||||
"file exceeds the workspace download limit; use preview for partial text or download a smaller file"
|
||||
),
|
||||
"size": result.size,
|
||||
}, 413
|
||||
filename = result.path.rsplit("/", 1)[-1] or "download"
|
||||
return Response(
|
||||
result.content,
|
||||
mimetype="application/octet-stream",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(len(result.content)),
|
||||
"X-Workspace-File-Size": str(result.size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files")
|
||||
class AgentAppWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_agent_app_workspace_files")
|
||||
@console_ns.doc(description="List a directory in an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceListQuery)})
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceListQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/preview")
|
||||
class AgentAppWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in an Agent App conversation's sandbox workspace")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/download")
|
||||
class AgentAppWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Download a file from an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files"
|
||||
)
|
||||
class WorkflowAgentWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_workflow_agent_workspace_files")
|
||||
@console_ns.doc(description="List a directory in a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceListQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/preview"
|
||||
)
|
||||
class WorkflowAgentWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in a Workflow Agent node's sandbox workspace")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/download"
|
||||
)
|
||||
class WorkflowAgentWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Download a file from a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -25,6 +25,9 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
@ -34,8 +37,8 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from libs.login import login_required
|
||||
from models import Account, App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -55,22 +58,23 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
mode: AppListMode = Field(default=cast(AppListMode, "all"), description="App mode filter")
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
creator_ids: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@ -91,10 +95,29 @@ class AppListQuery(BaseModel):
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
@field_validator("creator_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_creator_ids(cls, value: list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("Unsupported creator_ids type.")
|
||||
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in creator_ids.") from exc
|
||||
|
||||
|
||||
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
indexed_creator_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
@ -102,12 +125,19 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
match = _CREATOR_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_creator_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
if indexed_creator_ids:
|
||||
normalized["creator_ids"] = [value for _, value in sorted(indexed_creator_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
@ -115,7 +145,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
@ -393,6 +425,8 @@ class AppDetailWithSite(AppDetail):
|
||||
max_active_requests: int | None = None
|
||||
deleted_tools: list[DeletedTool] = Field(default_factory=list)
|
||||
site: Site | None = None
|
||||
# For Agent App type: the roster Agent backing this app (None otherwise).
|
||||
bound_agent_id: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -467,10 +501,11 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
@with_session(write=False)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
@ -478,12 +513,13 @@ class AppListApi(Resource):
|
||||
mode=args.mode,
|
||||
name=args.name,
|
||||
tag_ids=args.tag_ids,
|
||||
creator_ids=args.creator_ids,
|
||||
is_created_by_me=args.is_created_by_me,
|
||||
)
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -504,7 +540,7 @@ class AppListApi(Resource):
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
@ -543,9 +579,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
@ -648,11 +685,10 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -731,7 +767,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -739,13 +776,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
redirect_url = get_redirect_url(current_user_id, claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@ -9,9 +9,11 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -48,9 +50,9 @@ class AppImportApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -19,7 +19,13 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -31,7 +37,7 @@ from core.helper.trace_id_helper import get_external_trace_id
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
@ -41,9 +47,24 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_debugger_chat_streaming(
|
||||
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
|
||||
) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode != "blocking"
|
||||
if response_mode_provided and response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
|
||||
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
|
||||
# debugging still pass it. Optional here, required in practice by those modes
|
||||
# downstream when their config is built from args.
|
||||
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
@ -84,7 +105,8 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -92,8 +114,6 @@ class CompletionMessageApi(Resource):
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -157,13 +176,21 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
|
||||
)
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
args["response_mode"] = "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -171,8 +198,6 @@ class ChatMessageApi(Resource):
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account or EndUser instance")
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming
|
||||
)
|
||||
@ -211,15 +236,14 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
|
||||
@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -31,8 +36,9 @@ from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -205,10 +212,10 @@ class ChatConversationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
@ -316,12 +323,13 @@ class ChatConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -332,11 +340,11 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def _get_conversation(current_user: Account, app_model, conversation_id):
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -11,6 +13,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import with_session
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@ -19,11 +22,11 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_generator_service import WorkflowGeneratorService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@ -41,6 +44,24 @@ class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
class WorkflowGeneratePayload(BaseModel):
|
||||
"""Payload for the cmd+k `/create` and `/refine` workflow generator endpoint.
|
||||
|
||||
See ``services/workflow_generator_service.py`` for behaviour. Errors are
|
||||
surfaced through the same envelope as ``/rule-generate`` so the frontend
|
||||
can reuse its existing handler.
|
||||
"""
|
||||
|
||||
mode: Literal["workflow", "advanced-chat"] = Field(..., description="Target app mode for the generated graph")
|
||||
instruction: str = Field(..., description="Natural-language workflow description")
|
||||
ideal_output: str = Field(default="", description="Optional sample output for grounding")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
current_graph: dict | None = Field(
|
||||
default=None,
|
||||
description="Existing draft graph to refine (cmd+k `/refine`); omit for create-from-scratch",
|
||||
)
|
||||
|
||||
|
||||
register_enum_models(console_ns, LLMMode)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -49,6 +70,7 @@ register_schema_models(
|
||||
RuleStructuredOutputPayload,
|
||||
InstructionGeneratePayload,
|
||||
InstructionTemplatePayload,
|
||||
WorkflowGeneratePayload,
|
||||
ModelConfig,
|
||||
)
|
||||
|
||||
@ -158,7 +180,8 @@ class InstructionGenerateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
@with_session(write=False)
|
||||
def post(self, session: Session, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
@ -168,10 +191,10 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.get(App, args.flow_id)
|
||||
app = session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
@ -263,3 +286,56 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
|
||||
|
||||
@console_ns.route("/workflow-generate")
|
||||
class WorkflowGenerateApi(Resource):
|
||||
"""Generate a Workflow / Chatflow draft graph from a natural-language description.
|
||||
|
||||
Triggered by the cmd+k `/create` slash command. Returns a graph payload
|
||||
shaped exactly like ``WorkflowService.sync_draft_workflow``'s input, so the
|
||||
frontend can hand it straight to ``/apps/{id}/workflows/draft``.
|
||||
"""
|
||||
|
||||
@console_ns.doc("generate_workflow_graph")
|
||||
@console_ns.doc(description="Generate a Dify workflow graph from natural language")
|
||||
@console_ns.expect(console_ns.models[WorkflowGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Workflow graph generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = WorkflowGeneratePayload.model_validate(console_ns.payload)
|
||||
|
||||
# Reject obviously-empty instructions at the boundary — Pydantic only
|
||||
# validates ``instruction`` is a str, but a whitespace-only string
|
||||
# would still hit the LLM and waste a planner+builder roundtrip on a
|
||||
# response that the postprocess validator would reject anyway.
|
||||
if not args.instruction.strip():
|
||||
return {
|
||||
"error": "Instruction is required",
|
||||
"errors": [{"code": "EMPTY_INSTRUCTION", "detail": "Instruction is required"}],
|
||||
}, 400
|
||||
|
||||
try:
|
||||
result = WorkflowGeneratorService.generate_workflow_graph(
|
||||
tenant_id=current_tenant_id,
|
||||
mode=args.mode,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
current_graph=args.current_graph,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
return result
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
@ -43,7 +44,8 @@ from fields.conversation_fields import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -178,7 +180,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args.message_id)
|
||||
@ -336,9 +337,9 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
try:
|
||||
|
||||
@ -8,14 +8,20 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
created_by=current_user_id,
|
||||
updated_by=current_user_id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_by = current_user_id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -14,12 +14,14 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Site
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -85,9 +87,9 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -134,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
|
||||
@ -8,13 +8,14 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import AppMode
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
@ -48,9 +49,8 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -109,9 +109,8 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -169,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -230,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -293,10 +290,9 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
@ -374,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
@ -444,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -505,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
|
||||
@ -7,12 +7,18 @@ from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_valida
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
@ -213,9 +219,10 @@ class WorkflowCommentListApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@ -229,12 +236,14 @@ class WorkflowCommentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
@ -258,10 +267,11 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
tenant_id=current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
@ -276,12 +286,14 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -302,10 +314,12 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -327,10 +341,12 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
@ -353,11 +369,13 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
@ -386,17 +404,19 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -416,15 +436,17 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
comment_id=comment_id, tenant_id=current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
reply_id=reply_id,
|
||||
@ -448,9 +470,13 @@ class WorkflowCommentMentionUsersApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
current_tenant = current_user.current_tenant # need the tenant object here
|
||||
if current_tenant is None:
|
||||
raise ValueError("current tenant is required")
|
||||
members = TenantService.get_tenant_members(current_tenant)
|
||||
users = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersPayload(users=users)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Concatenate, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -15,7 +15,12 @@ from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -27,8 +32,8 @@ from graphon.file import helpers as file_helpers
|
||||
from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -123,14 +128,15 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
def ensure_variable_access(
|
||||
variable: WorkflowDraftVariable | None,
|
||||
app_id: str,
|
||||
variable_id: str,
|
||||
current_user_id: str,
|
||||
) -> WorkflowDraftVariable:
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_id or variable.user_id != current_user.id:
|
||||
if variable.app_id != app_id or variable.user_id != current_user_id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@ -214,7 +220,9 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -230,9 +238,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -249,7 +258,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -279,7 +288,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
def delete(self, current_user: Account, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -313,7 +322,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
def get(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -327,7 +336,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
def delete(self, current_user: Account, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id)
|
||||
@ -347,15 +356,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
def get(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -366,7 +376,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
def patch(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -394,10 +404,11 @@ class VariableApi(Resource):
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -438,15 +449,16 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
def delete(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -462,7 +474,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
def put(self, current_user: Account, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -474,10 +486,11 @@ class VariableResetApi(Resource):
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
@ -488,20 +501,20 @@ class VariableResetApi(Resource):
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(app_model: App, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
return draft_vars
|
||||
|
||||
@ -515,7 +528,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
@ -525,7 +538,7 @@ class ConversationVariableCollectionApi(Resource):
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@ -537,7 +550,8 @@ class ConversationVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
@ -564,8 +578,8 @@ class SystemVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
@ -576,7 +590,7 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
def get(self, _current_user: Account, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -617,7 +631,8 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
@ -12,7 +12,12 @@ from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id
|
||||
from extensions.ext_database import db
|
||||
@ -30,8 +35,8 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.archive_storage import ArchiveStorageNotConfiguredError, get_archive_storage
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
@ -190,8 +195,8 @@ class WorkflowRunExportApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
tenant_id = app_model.tenant_id
|
||||
app_id = app_model.id
|
||||
run_id_str = str(run_id)
|
||||
|
||||
run_created_at = db.session.scalar(
|
||||
@ -397,18 +402,18 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, run_id: UUID):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id_str,
|
||||
user=user,
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
@ -432,7 +437,8 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
@ -449,7 +455,7 @@ class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
if workflow_run.tenant_id != current_user.current_tenant_id:
|
||||
if workflow_run.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
|
||||
@ -6,10 +6,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -12,14 +12,14 @@ from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.model import App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required, with_current_tenant_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -124,18 +124,16 @@ class AppTriggersApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerListResponse.__name__])
|
||||
def get(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
@ -166,19 +164,18 @@ class AppTriggerEnableApi(Resource):
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@console_ns.response(200, "Success", console_ns.models[WorkflowTriggerResponse.__name__])
|
||||
def post(self, app_model: App):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.tenant_id == current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
@ -76,7 +76,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args.language in languages:
|
||||
if args.language is not None and args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
|
||||
@ -32,11 +32,11 @@ from controllers.console.wraps import (
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
@ -46,6 +46,7 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
@ -172,9 +173,8 @@ class LoginApi(Resource):
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
|
||||
@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -32,8 +39,9 @@ class Subscription(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
@ -45,8 +53,9 @@ class Invoices(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, partner_key: str):
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
|
||||
@ -3,11 +3,18 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
|
||||
@ -1,41 +1,37 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse, TextContentResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.entities.knowledge_entities import IndexingEstimate
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import (
|
||||
integrate_fields,
|
||||
integrate_icon_fields,
|
||||
integrate_list_fields,
|
||||
integrate_notion_info_list_fields,
|
||||
integrate_page_fields,
|
||||
integrate_workspace_fields,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
@ -54,50 +50,74 @@ class DataSourceNotionPreviewQuery(BaseModel):
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
|
||||
class DataSourceIntegrateIconResponse(ResponseModel):
|
||||
type: str | None = None
|
||||
url: str | None = None
|
||||
emoji: str | None = None
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
class DataSourceIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str
|
||||
type: str
|
||||
|
||||
integrate_page_fields_copy = integrate_page_fields.copy()
|
||||
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
|
||||
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
|
||||
|
||||
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
|
||||
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
|
||||
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
|
||||
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[DataSourceIntegratePageResponse]
|
||||
total: int
|
||||
|
||||
integrate_fields_copy = integrate_fields.copy()
|
||||
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
|
||||
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
|
||||
|
||||
integrate_list_fields_copy = integrate_list_fields.copy()
|
||||
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
|
||||
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
|
||||
class DataSourceIntegrateResponse(ResponseModel):
|
||||
id: str | None
|
||||
provider: str
|
||||
created_at: datetime | int | None
|
||||
is_bound: bool
|
||||
disabled: bool | None
|
||||
link: str
|
||||
source_info: DataSourceIntegrateWorkspaceResponse | None
|
||||
|
||||
notion_page_fields = {
|
||||
"page_name": fields.String,
|
||||
"page_id": fields.String,
|
||||
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
|
||||
"is_bound": fields.Boolean,
|
||||
"parent_id": fields.String,
|
||||
"type": fields.String,
|
||||
}
|
||||
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
|
||||
@field_serializer("created_at")
|
||||
def serialize_created_at(self, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
notion_workspace_fields = {
|
||||
"workspace_name": fields.String,
|
||||
"workspace_id": fields.String,
|
||||
"workspace_icon": fields.String,
|
||||
"pages": fields.List(fields.Nested(notion_page_model)),
|
||||
}
|
||||
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
|
||||
|
||||
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
|
||||
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
|
||||
integrate_notion_info_list_model = get_or_create_model(
|
||||
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
|
||||
class DataSourceIntegrateListResponse(ResponseModel):
|
||||
data: list[DataSourceIntegrateResponse]
|
||||
|
||||
|
||||
class NotionIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str | None
|
||||
type: str
|
||||
is_bound: bool
|
||||
|
||||
|
||||
class NotionIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[NotionIntegratePageResponse]
|
||||
|
||||
|
||||
class NotionIntegrateInfoListResponse(ResponseModel):
|
||||
notion_info: list[NotionIntegrateWorkspaceResponse]
|
||||
|
||||
|
||||
register_schema_models(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DataSourceIntegrateListResponse,
|
||||
IndexingEstimate,
|
||||
NotionIntegrateInfoListResponse,
|
||||
SimpleResultResponse,
|
||||
TextContentResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -109,10 +129,9 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -154,19 +173,21 @@ class DataSourceApi(Resource):
|
||||
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||
}
|
||||
)
|
||||
return {"data": integrate_data}, 200
|
||||
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
|
||||
) -> tuple[dict[str, str], int]:
|
||||
binding_id_str = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
@ -198,12 +219,12 @@ class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_model)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -278,22 +299,22 @@ class DataSourceNotionListApi(Resource):
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
|
||||
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
|
||||
class DataSourceNotionPreviewApi(Resource):
|
||||
"""Preview one authorized Notion page through the datasource credential."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
@ -316,13 +337,18 @@ class DataSourceNotionApi(Resource):
|
||||
text_docs = extractor.extract()
|
||||
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/notion-indexing-estimate")
|
||||
class DataSourceNotionIndexingEstimateApi(Resource):
|
||||
"""Estimate indexing work for selected Notion pages."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
# validate args
|
||||
@ -355,7 +381,7 @@ class DataSourceNotionApi(Resource):
|
||||
args["doc_form"],
|
||||
args["doc_language"],
|
||||
)
|
||||
return response.model_dump(), 200
|
||||
return dump_response(IndexingEstimate, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
@ -364,7 +390,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -382,7 +408,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -9,7 +9,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -34,16 +34,18 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
document_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentStatusListResponse,
|
||||
DocumentStatusResponse,
|
||||
normalize_enum,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
@ -69,17 +71,13 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class DatasetResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -93,7 +91,7 @@ class DatasetResponse(ResponseModel):
|
||||
@field_validator("data_source_type", "indexing_technique", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
return normalize_enum(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
@ -101,61 +99,10 @@ class DatasetResponse(ResponseModel):
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class DocumentResponse(ResponseModel):
|
||||
id: str
|
||||
position: int | None = None
|
||||
data_source_type: str | None = None
|
||||
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
|
||||
data_source_detail_dict: Any = None
|
||||
dataset_process_rule_id: str | None = None
|
||||
name: str
|
||||
created_from: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
tokens: int | None = None
|
||||
indexing_status: str | None = None
|
||||
error: str | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
archived: bool | None = None
|
||||
display_status: str | None = None
|
||||
word_count: int | None = None
|
||||
hit_count: int | None = None
|
||||
doc_form: str | None = None
|
||||
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
|
||||
summary_index_status: str | None = None
|
||||
need_summary: bool | None = None
|
||||
|
||||
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
@field_validator("doc_metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
process_rule_dict: Any = None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
|
||||
|
||||
class DatasetAndDocumentResponse(ResponseModel):
|
||||
@ -190,6 +137,14 @@ class DocumentDatasetListParam(BaseModel):
|
||||
fetch_val: str = Field("false", alias="fetch")
|
||||
|
||||
|
||||
class DocumentWithSegmentsListResponse(ResponseModel):
|
||||
data: list[DocumentWithSegmentsResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
@ -200,18 +155,25 @@ register_schema_models(
|
||||
GenerateSummaryPayload,
|
||||
DocumentMetadataUpdatePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultMessageResponse,
|
||||
SimpleResultResponse,
|
||||
UrlResponse,
|
||||
DatasetResponse,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
DocumentWithSegmentsListResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
def get_document(
|
||||
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
|
||||
) -> Document:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -231,8 +193,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -259,8 +220,8 @@ class GetProcessRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -312,12 +273,17 @@ class DatasetDocumentListApi(Resource):
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Documents retrieved successfully",
|
||||
console_ns.models[DocumentWithSegmentsListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
@ -425,18 +391,15 @@ class DatasetDocumentListApi(Resource):
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
data = marshal(documents, document_with_segments_fields)
|
||||
else:
|
||||
data = marshal(documents, document_fields)
|
||||
response = {
|
||||
"data": data,
|
||||
"data": documents,
|
||||
"has_more": len(documents) == limit,
|
||||
"limit": limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": page,
|
||||
}
|
||||
|
||||
return response
|
||||
return dump_response(DocumentWithSegmentsListResponse, response)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -445,8 +408,8 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -482,9 +445,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -522,9 +483,10 @@ class DatasetInitApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -567,9 +529,7 @@ class DatasetInitApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
@ -583,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -648,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -742,12 +704,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -784,9 +750,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
@ -794,21 +759,25 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -817,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -839,7 +808,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
return marshal(document_dict, document_status_fields)
|
||||
return dump_response(DocumentStatusResponse, document_dict)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
@ -860,10 +829,12 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -949,7 +920,9 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -958,7 +931,7 @@ class DocumentApi(DocumentResource):
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -979,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@ -996,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
@ -1043,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["pause", "resume"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1091,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1140,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(
|
||||
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -1256,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in payload.document_ids:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
@ -1288,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1304,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(DocumentResponse, document)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
@ -1313,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -1391,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1399,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
@ -1484,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1497,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from typing import cast as type_cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -13,7 +14,12 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@ -27,6 +33,8 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
@ -34,30 +42,29 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from fields.segment_fields import (
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkResponse,
|
||||
SegmentDetailResponse,
|
||||
SegmentResponse,
|
||||
segment_response_with_summary,
|
||||
segment_responses_with_summaries,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -67,6 +74,16 @@ class SegmentListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentIdListQuery(BaseModel):
|
||||
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
@ -92,13 +109,35 @@ class SegmentBatchImportStatusResponse(ResponseModel):
|
||||
job_status: str
|
||||
|
||||
|
||||
class ConsoleSegmentListResponse(ResponseModel):
|
||||
data: list[SegmentResponse]
|
||||
limit: int
|
||||
total: int
|
||||
total_pages: int
|
||||
page: int
|
||||
|
||||
|
||||
class ChildChunkBatchUpdateResponse(ResponseModel):
|
||||
data: list[ChildChunkResponse]
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
class SegmentDocParams:
|
||||
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
|
||||
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
|
||||
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
|
||||
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
|
||||
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SegmentListQuery,
|
||||
SegmentIdListQuery,
|
||||
ChildChunkListQuery,
|
||||
SegmentCreatePayload,
|
||||
SegmentUpdatePayload,
|
||||
BatchImportPayload,
|
||||
@ -107,17 +146,30 @@ register_schema_models(
|
||||
ChildChunkBatchUpdatePayload,
|
||||
ChildChunkUpdateArgs,
|
||||
)
|
||||
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SegmentResponse,
|
||||
ConsoleSegmentListResponse,
|
||||
SegmentDetailResponse,
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkBatchUpdateResponse,
|
||||
SegmentBatchImportStatusResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
|
||||
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -134,12 +186,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
**request.args.to_dict(),
|
||||
"status": request.args.getlist("status"),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -169,12 +216,17 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Guard with jsonb_typeof to avoid "cannot extract elements from a scalar" error
|
||||
# when keywords is null or a non-array JSON value.
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
.where(func.jsonb_typeof(cast(DocumentSegment.keywords, JSONB)) == "array")
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
@ -200,42 +252,33 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
segment_list = list(segments.items)
|
||||
segment_ids = [segment.id for segment in segment_list]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": segments_with_summary,
|
||||
"data": segment_responses_with_summaries(segment_list, summaries),
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(ConsoleSegmentListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -263,15 +306,24 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["enable", "disable"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -316,11 +368,12 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
SegmentService.update_segments_status(segment_ids, action, dataset, document)
|
||||
except Exception as e:
|
||||
raise InvalidActionError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -328,9 +381,10 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -367,21 +421,30 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -435,16 +498,24 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -490,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -518,11 +589,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
try:
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id),
|
||||
job_id,
|
||||
upload_file_id,
|
||||
dataset_id_str,
|
||||
document_id_str,
|
||||
@ -531,7 +602,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
|
||||
|
||||
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@ -546,11 +617,13 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
response = {"job_id": job_id, "job_status": cache_result.decode()}
|
||||
return dump_response(SegmentBatchImportStatusResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -558,9 +631,12 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -608,14 +684,16 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
|
||||
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -637,13 +715,7 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -652,22 +724,32 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
response = {
|
||||
"data": child_chunks.items,
|
||||
"total": child_chunks.total,
|
||||
"total_pages": child_chunks.pages,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}, 200
|
||||
}
|
||||
return dump_response(ChildChunkListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Child chunks updated successfully",
|
||||
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -702,7 +784,7 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -713,10 +795,19 @@ class ChildChunkUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -743,7 +834,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -770,10 +861,20 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -800,7 +901,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -822,4 +923,4 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -29,7 +30,8 @@ from fields.dataset_fields import (
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -9,11 +9,18 @@ from configs import dify_config
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -195,15 +200,17 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, user: Account, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
user=user,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
@ -216,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -241,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -264,9 +269,8 @@ class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -277,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -292,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -310,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
@ -330,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -352,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@ -2,13 +2,12 @@ from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -30,13 +29,11 @@ class DataSourceContentPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
|
||||
inputs = args.inputs
|
||||
|
||||
@ -1,13 +1,20 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -16,79 +23,132 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineTemplateListQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
language: str = Field(default="en-US", description="Template language")
|
||||
|
||||
|
||||
class PipelineTemplateDetailQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
|
||||
|
||||
class PipelineTemplateItemResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon: dict[str, Any]
|
||||
description: str
|
||||
position: int
|
||||
chunk_structure: str
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
|
||||
|
||||
class PipelineTemplateListResponse(ResponseModel):
|
||||
pipeline_templates: list[PipelineTemplateItemResponse]
|
||||
|
||||
|
||||
class PipelineTemplateDetailResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon_info: dict[str, Any]
|
||||
description: str
|
||||
chunk_structure: str
|
||||
export_data: str
|
||||
graph: dict[str, Any]
|
||||
created_by: str | None = None
|
||||
|
||||
|
||||
class CustomizedPipelineTemplatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] = Field(default_factory=lambda: IconInfo(icon="").model_dump())
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CustomizedPipelineTemplatePayload,
|
||||
PipelineTemplateDetailQuery,
|
||||
PipelineTemplateListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
PipelineTemplateDetailResponse,
|
||||
PipelineTemplateListResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateListQuery))
|
||||
@console_ns.response(200, "Pipeline templates", console_ns.models[PipelineTemplateListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
def get(self) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
|
||||
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateDetailQuery))
|
||||
@console_ns.response(200, "Pipeline template", console_ns.models[PipelineTemplateDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
def get(self, template_id: str) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateDetailQuery.model_validate(request.args.to_dict(flat=True))
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, query.type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
register_response_schema_models(console_ns, SimpleDataResponse)
|
||||
raise NotFound("Pipeline template not found from upstream service.")
|
||||
return dump_response(PipelineTemplateDetailResponse, pipeline_template), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template updated")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def patch(self, template_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(204, "Pipeline template deleted")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
def delete(self, template_id: str) -> tuple[str, int]:
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
|
||||
def post(self, template_id: str):
|
||||
def post(self, template_id: str) -> JsonResponseWithStatus:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||
@ -96,19 +156,20 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": template.yaml_content}), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template published")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def post(self, pipeline_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
return {"result": "success"}
|
||||
return "", 204
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.schema import JsonResponseWithStatus, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import RagPipelineImportResponse
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -25,19 +30,26 @@ class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_schema_models(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_response_schema_models(console_ns, DatasetDetailResponse, RagPipelineImportResponse)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"RAG pipeline dataset import started",
|
||||
console_ns.models[RagPipelineImportResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -70,19 +82,20 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
return dump_response(RagPipelineImportResponse, import_info), 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@console_ns.response(201, "RAG pipeline dataset created", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
@ -99,4 +112,4 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -21,7 +22,7 @@ from controllers.console.app.workflow_draft_variable import (
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -29,7 +30,7 @@ from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -57,7 +58,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -72,10 +75,12 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -84,7 +89,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
@ -112,7 +117,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -143,7 +148,7 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
def get(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -154,7 +159,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
def delete(self, current_user: Account, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id)
|
||||
@ -169,7 +174,7 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -184,7 +189,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def patch(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -253,7 +258,7 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def delete(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -271,7 +276,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def put(self, _current_user: Account, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -297,17 +302,17 @@ class RagPipelineVariableResetApi(Resource):
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
def _get_variable_list(pipeline: Pipeline, node_id: str, current_user_id: str) -> WorkflowDraftVariableList:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user_id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user_id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id)
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
@ -315,14 +320,14 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
def get(self, current_user: Account, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID, current_user.id)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables")
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, _current_user: Account, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
|
||||
@ -1,23 +1,29 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import (
|
||||
leaked_dependency_fields,
|
||||
pipeline_import_check_dependencies_fields,
|
||||
pipeline_import_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -36,35 +42,45 @@ class RagPipelineImportPayload(BaseModel):
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
include_secret: str = Field(default="false")
|
||||
include_secret: str = Field(default="false", description="Whether to include secret values in the exported DSL")
|
||||
|
||||
|
||||
class RagPipelineImportResponse(ResponseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
pipeline_id: str | None = None
|
||||
dataset_id: str | None = None
|
||||
current_dsl_version: str
|
||||
imported_dsl_version: str
|
||||
error: str = ""
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesResponse(ResponseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||
|
||||
|
||||
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
|
||||
|
||||
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
|
||||
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
|
||||
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
|
||||
fields.Nested(leaked_dependency_model)
|
||||
)
|
||||
pipeline_import_check_dependencies_model = get_or_create_model(
|
||||
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
RagPipelineImportCheckDependenciesResponse,
|
||||
RagPipelineImportResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
@console_ns.response(200, "Import completed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(202, "Import pending confirmation", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account) -> JsonResponseWithStatus:
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Use a plain Session so that caught exceptions inside the service
|
||||
@ -91,23 +107,23 @@ class RagPipelineImportApi(Resource):
|
||||
status = result.status
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return dump_response(RagPipelineImportResponse, result), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@console_ns.response(200, "Import confirmed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
account = current_user
|
||||
@ -119,34 +135,40 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dependencies checked",
|
||||
console_ns.models[RagPipelineImportCheckDependenciesResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportCheckDependenciesResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(IncludeSecretQuery))
|
||||
@console_ns.response(200, "Pipeline exported", console_ns.models[SimpleDataResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@ -156,4 +178,4 @@ class RagPipelineExportApi(Resource):
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": result}), 200
|
||||
|
||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.scalar(
|
||||
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
||||
@ -8,10 +8,11 @@ from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from constants.languages import languages
|
||||
from controllers.common.schema import query_params_from_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from controllers.console.wraps import account_initialization_required, with_current_user
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
|
||||
|
||||
@ -79,13 +80,14 @@ class RecommendedAppListApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RecommendedAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
elif current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
@ -2,13 +2,36 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
from services.feature_service import (
|
||||
FeatureModel,
|
||||
FeatureService,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
)
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
class TrialModelsResponse(ResponseModel):
|
||||
trial_models: list[str]
|
||||
|
||||
|
||||
class AppDslVersionResponse(ResponseModel):
|
||||
app_dsl_version: str
|
||||
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AppDslVersionResponse,
|
||||
FeatureModel,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
TrialModelsResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -54,6 +77,43 @@ class FeatureVectorSpaceApi(Resource):
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/trial-models")
|
||||
class TrialModelsApi(Resource):
|
||||
@console_ns.doc("get_trial_models")
|
||||
@console_ns.doc(description="Get hosted trial model provider configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[TrialModelsResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Get hosted trial model provider configuration for model-provider pages."""
|
||||
return dump_response(
|
||||
TrialModelsResponse,
|
||||
{"trial_models": FeatureService.get_trial_models()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app-dsl-version")
|
||||
class AppDslVersionApi(Resource):
|
||||
@console_ns.doc("get_app_dsl_version")
|
||||
@console_ns.doc(description="Get current app DSL version")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[AppDslVersionResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
"""Get current app DSL version for workflow clipboard compatibility."""
|
||||
return dump_response(
|
||||
AppDslVersionResponse,
|
||||
{"app_dsl_version": FeatureService.get_app_dsl_version()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
|
||||
@ -12,8 +12,15 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
model_validate,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -22,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
@ -33,6 +40,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(console_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
@ -45,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
|
||||
"""Ensure a console form token resolves only inside the current tenant."""
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@ -59,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
@ -70,13 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(
|
||||
self,
|
||||
payload: HumanInputFormSubmitPayload,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
form_token: str,
|
||||
):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@ -90,15 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
@ -122,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@ -130,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
164
api/controllers/console/snippets/payloads.py
Normal file
164
api/controllers/console/snippets/payloads.py
Normal file
@ -0,0 +1,164 @@
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Filter by creator account IDs",
|
||||
validation_alias=AliasChoices("creators", "creator_id"),
|
||||
)
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
return cls._normalize_string_list(value)
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def parse_tag_ids(cls, value: object) -> list[str] | None:
|
||||
"""Normalize and validate tag IDs from query string or list input."""
|
||||
items = cls._normalize_string_list(value)
|
||||
if not items:
|
||||
return None
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string_list(value: object) -> list[str] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(item).strip() for item in value if str(item).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
678
api/controllers/console/snippets/snippet_workflow.py
Normal file
678
api/controllers/console/snippets/snippet_workflow.py
Normal file
@ -0,0 +1,678 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.workflow import (
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowResponse,
|
||||
)
|
||||
from controllers.console.snippets.payloads import (
|
||||
PublishWorkflowPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.workflow_run_fields import (
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
)
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
from services.snippet_generate_service import SnippetGenerateService
|
||||
from services.snippet_service import SnippetService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Register Pydantic models with Swagger
|
||||
|
||||
|
||||
def _snippet_session_maker() -> sessionmaker[Session]:
|
||||
return sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
return SnippetService(_snippet_session_maker())
|
||||
|
||||
|
||||
class SnippetWorkflowResponse(WorkflowResponse):
|
||||
input_fields: list[dict] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SnippetDraftSyncPayload,
|
||||
SnippetDraftNodeRunPayload,
|
||||
SnippetDraftRunPayload,
|
||||
SnippetIterationNodeRunPayload,
|
||||
SnippetLoopNodeRunPayload,
|
||||
SnippetWorkflowListQuery,
|
||||
WorkflowRunQuery,
|
||||
PublishWorkflowPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SnippetWorkflowResponse,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowRunPaginationResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
)
|
||||
|
||||
|
||||
class SnippetNotFoundError(Exception):
|
||||
"""Snippet not found error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Decorator to fetch and validate snippet access."""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("snippet_id"):
|
||||
raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
snippet_id = str(kwargs.get("snippet_id"))
|
||||
del kwargs["snippet_id"]
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
snippet = snippet_service.get_snippet_by_id(
|
||||
snippet_id=snippet_id,
|
||||
tenant_id=current_tenant_id,
|
||||
)
|
||||
|
||||
if not snippet:
|
||||
raise NotFound("Snippet not found")
|
||||
|
||||
kwargs["snippet"] = snippet
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
class SnippetDraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_workflow")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow retrieved successfully",
|
||||
console_ns.models[SnippetWorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get draft workflow for snippet."""
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
workflow.conversation_variables = []
|
||||
response = SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
response["input_fields"] = snippet.input_fields_list
|
||||
return response
|
||||
|
||||
@console_ns.doc("sync_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow synced successfully")
|
||||
@console_ns.response(400, "Hash mismatch")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Sync draft workflow for snippet."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.sync_draft_workflow(
|
||||
snippet=snippet,
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
input_fields=payload.input_fields,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
class SnippetDraftConfigApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_config")
|
||||
@console_ns.response(200, "Draft config retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get snippet draft workflow configuration limits."""
|
||||
return {
|
||||
"parallel_depth_limit": 3,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
class SnippetPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_snippet_published_workflow")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflow retrieved successfully",
|
||||
console_ns.models[SnippetWorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get published workflow for snippet."""
|
||||
if not snippet.is_published:
|
||||
return None
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
if not workflow:
|
||||
return None
|
||||
|
||||
response = SnippetWorkflowResponse.model_validate(workflow, from_attributes=True).model_dump(mode="json")
|
||||
response["input_fields"] = snippet.input_fields_list
|
||||
return response
|
||||
|
||||
@console_ns.doc("publish_snippet_workflow")
|
||||
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
@console_ns.response(200, "Workflow published successfully")
|
||||
@console_ns.response(400, "No draft workflow found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""Publish snippet workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
snippet = session.merge(snippet)
|
||||
try:
|
||||
workflow = snippet_service.publish_workflow(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
account=current_user,
|
||||
)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
class SnippetDefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_snippet_default_block_configs")
|
||||
@console_ns.response(200, "Default block configs retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get default block configurations for snippet workflow."""
|
||||
snippet_service = _snippet_service()
|
||||
return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
class SnippetPublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
@console_ns.doc("get_all_snippet_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for a snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflows retrieved successfully",
|
||||
console_ns.models[WorkflowPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""Get all published workflow versions for snippet."""
|
||||
args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
session=session,
|
||||
snippet=snippet,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": args.page,
|
||||
"limit": args.limit,
|
||||
"has_more": has_more,
|
||||
},
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
|
||||
class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.doc("restore_snippet_workflow_to_draft")
|
||||
@console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
|
||||
@console_ns.response(200, "Workflow restored successfully")
|
||||
@console_ns.response(400, "Source workflow must be published")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, workflow_id: str):
|
||||
"""Restore a published snippet workflow version into the draft workflow."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
snippet_service = _snippet_service()
|
||||
|
||||
try:
|
||||
workflow = snippet_service.restore_published_workflow_to_draft(
|
||||
snippet=snippet,
|
||||
workflow_id=workflow_id,
|
||||
account=current_user,
|
||||
)
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"hash": workflow.unique_hash,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
class SnippetWorkflowRunsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_runs")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow runs retrieved successfully",
|
||||
console_ns.models[WorkflowRunPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet):
|
||||
"""List workflow runs for snippet."""
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": query.last_id,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
class SnippetWorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_run_detail")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow run detail retrieved successfully",
|
||||
console_ns.models[WorkflowRunDetailResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""Get workflow run detail for snippet."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
if not workflow_run:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
@console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Node executions retrieved successfully",
|
||||
console_ns.models[WorkflowRunNodeExecutionListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, run_id):
|
||||
"""List node executions for a workflow run."""
|
||||
run_id = str(run_id)
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
snippet=snippet,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionListResponse.model_validate(
|
||||
{"data": node_executions}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class SnippetDraftNodeRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_node")
|
||||
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
@console_ns.response(
|
||||
200, "Node run completed successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a single node in snippet draft workflow.
|
||||
|
||||
Executes a specific node with provided inputs for single-step debugging.
|
||||
Returns the node execution result including status, outputs, and timing.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
user_inputs = payload.inputs
|
||||
|
||||
# Get draft workflow for file parsing
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
snippet=snippet,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=current_user,
|
||||
query=payload.query,
|
||||
files=files,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(
|
||||
workflow_node_execution, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
class SnippetDraftNodeLastRunApi(Resource):
|
||||
@console_ns.doc("get_snippet_draft_node_last_run")
|
||||
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.response(
|
||||
200, "Node last run retrieved successfully", console_ns.models[WorkflowRunNodeExecutionResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
Returns the most recent execution record for the given node,
|
||||
including status, inputs, outputs, and timing information.
|
||||
"""
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise NotFound("Draft workflow not found")
|
||||
|
||||
node_exec = snippet_service.get_snippet_node_last_run(
|
||||
snippet=snippet,
|
||||
workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
)
|
||||
if node_exec is None:
|
||||
raise NotFound("Node last run not found")
|
||||
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(node_exec, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_iteration_node")
|
||||
@console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow iteration node for snippet.
|
||||
|
||||
Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
Returns an SSE event stream with iteration progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_iteration(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class SnippetDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_loop_node")
|
||||
@console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
"""
|
||||
Run a draft workflow loop node for snippet.
|
||||
|
||||
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
Returns an SSE event stream with loop progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate_single_loop(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
node_id=node_id,
|
||||
args=args,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
class SnippetDraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_snippet_draft_workflow")
|
||||
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
@console_ns.response(404, "Snippet or draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet):
|
||||
"""
|
||||
Run draft workflow for snippet.
|
||||
|
||||
Executes the snippet's draft workflow with the provided inputs
|
||||
and returns an SSE event stream with execution progress and results.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = SnippetGenerateService.generate(
|
||||
snippet=snippet,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
session_maker=_snippet_session_maker(),
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class SnippetWorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_snippet_workflow_task")
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(404, "Snippet not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
"""
|
||||
Stop a running snippet workflow task.
|
||||
|
||||
Uses both the legacy stop flag mechanism and the graph engine
|
||||
command channel for backward compatibility.
|
||||
"""
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
return {"result": "success"}
|
||||
@ -0,0 +1,334 @@
|
||||
"""
|
||||
Snippet draft workflow variable APIs.
|
||||
|
||||
Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
(`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
WorkflowDraftVariableListQuery,
|
||||
WorkflowDraftVariableUpdatePayload,
|
||||
ensure_variable_access,
|
||||
validate_node_id,
|
||||
workflow_draft_variable_list_model,
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
workflow_draft_variable_model,
|
||||
)
|
||||
from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from graphon.variables.types import SegmentType
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
{SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
)
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
def _snippet_service() -> SnippetService:
|
||||
return SnippetService(sessionmaker(bind=db.engine, expire_on_commit=False))
|
||||
|
||||
|
||||
def _ensure_snippet_draft_variable_row_allowed(
|
||||
*,
|
||||
variable: WorkflowDraftVariable,
|
||||
variable_id: str,
|
||||
) -> None:
|
||||
"""Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
def _snippet_draft_var_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_snippet
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@wraps(f)
|
||||
def wrapper(self: T, current_user: Account, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, current_user, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
@console_ns.doc("get_snippet_workflow_variables")
|
||||
@console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow variables retrieved successfully",
|
||||
workflow_draft_variable_list_without_value_model,
|
||||
)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
snippet_service = _snippet_service()
|
||||
if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=snippet.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
user_id=current_user.id,
|
||||
exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class SnippetNodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@console_ns.doc("delete_snippet_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
class SnippetVariableApi(Resource):
|
||||
@console_ns.doc("get_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
new_name = args_model.name
|
||||
raw_value = args_model.value
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(
|
||||
mapping=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(
|
||||
mappings=raw_value,
|
||||
tenant_id=snippet.tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@console_ns.doc("delete_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def delete(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class SnippetVariableResetApi(Resource):
|
||||
@console_ns.doc("reset_snippet_workflow_variable")
|
||||
@console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def put(self, current_user: Account, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
snippet_service = _snippet_service()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
)
|
||||
variable = ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=snippet.id,
|
||||
variable_id=variable_id,
|
||||
current_user_id=current_user.id,
|
||||
)
|
||||
_ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
class SnippetConversationVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_conversation_variables")
|
||||
@console_ns.doc(
|
||||
description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
class SnippetSystemVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_system_variables")
|
||||
@console_ns.doc(
|
||||
description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
)
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@_snippet_draft_var_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_snippet_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@_snippet_draft_var_prerequisite
|
||||
def get(self, _current_user: Account, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
snippet_service = _snippet_service()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars_list: list[dict[str, Any]] = []
|
||||
for v in workflow.environment_variables:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.exposed_type().value,
|
||||
"value": v.value,
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
@ -51,7 +51,7 @@ class TagBindingRemovePayload(BaseModel):
|
||||
|
||||
|
||||
class TagListQueryParam(BaseModel):
|
||||
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
|
||||
type: Literal["knowledge", "app", "snippet", ""] = Field("", description="Tag type filter")
|
||||
keyword: str | None = Field(None, description="Search keyword")
|
||||
|
||||
|
||||
@ -96,7 +96,10 @@ class TagListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
params={
|
||||
"type": 'Tag type filter. Can be "knowledge", "app", or "snippet".',
|
||||
"keyword": "Search keyword for tag name.",
|
||||
}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
@with_current_tenant_id
|
||||
|
||||
@ -18,7 +18,7 @@ from controllers.common.fields import (
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -42,15 +42,17 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -173,7 +175,6 @@ class CheckEmailUniquePayload(BaseModel):
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AccountInitPayload,
|
||||
AccountNamePayload,
|
||||
AccountAvatarPayload,
|
||||
@ -245,6 +246,7 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
@ -329,20 +330,21 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
@ -355,15 +357,15 @@ class AccountAvatarApi(Resource):
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
account_integrates = db.session.scalars(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||
).all()
|
||||
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_deletion_verification_email(account, code)
|
||||
|
||||
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletePayload.model_validate(payload)
|
||||
|
||||
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
@ -563,9 +561,8 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = EducationActivatePayload.model_validate(payload)
|
||||
|
||||
@ -577,9 +574,8 @@ class EducationApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailSendPayload.model_validate(payload)
|
||||
|
||||
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
|
||||
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||
|
||||
@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
@ -96,17 +102,15 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the injected workspace and user."""
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, user_id: str, id: str):
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, user_id: str, id: str):
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from libs.login import login_required
|
||||
from models import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user