mirror of
https://github.com/langgenius/dify.git
synced 2026-05-20 08:46:57 +08:00
Compare commits
93 Commits
deploy/dev
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cf2612ea5 | |||
| 05408af8a1 | |||
| d3ae074456 | |||
| 0b48a7e991 | |||
| 809f513ccb | |||
| d9e90d0fa0 | |||
| d1417bbe4b | |||
| 2565637e36 | |||
| cae9923e5a | |||
| a328bbbced | |||
| 5276eb689b | |||
| 4b2badb6f2 | |||
| 34a89416f7 | |||
| a13ab76002 | |||
| b04b4449db | |||
| 674cdc3521 | |||
| 2031d31ee8 | |||
| 04d62867af | |||
| 7f392b6950 | |||
| b0a3399774 | |||
| 2d5186fb28 | |||
| 06f076e0ff | |||
| 5b79f7e99d | |||
| 1cee1a25b6 | |||
| c0f237bf35 | |||
| 75d7fc0526 | |||
| c057b5c5ff | |||
| 5468c4ec96 | |||
| f4c02e4c6b | |||
| 9dc95eeb20 | |||
| 76bba64b79 | |||
| 59e96fbb2a | |||
| 06ea0f7ac2 | |||
| 730a0bef9e | |||
| 2eb37caf2e | |||
| 7e8147295b | |||
| c07686928a | |||
| d1238180ed | |||
| 969760364d | |||
| ceabfeb3a7 | |||
| c407f40e0d | |||
| 28818f2e2a | |||
| e2c52c9b0f | |||
| 1925d58369 | |||
| b79fc5d6b4 | |||
| 6649e4025e | |||
| b96f372f45 | |||
| 127fbf2c9a | |||
| 3c70d28064 | |||
| cd4d6f8a22 | |||
| 9d0906c684 | |||
| 41b6f894c0 | |||
| e7e6fe8813 | |||
| c0bdd6792f | |||
| 27b084c4d4 | |||
| 3f7a68fc77 | |||
| a252fbddfa | |||
| ff02636a4b | |||
| 63946d829e | |||
| cdcfd2ef2c | |||
| b04a3851cc | |||
| b41338cd08 | |||
| 28153df4d3 | |||
| 3bc3386535 | |||
| 7654f14241 | |||
| 194b54bae4 | |||
| 0e16d36edb | |||
| 432a6412a3 | |||
| 55d05fe52d | |||
| 0d500e6965 | |||
| 5798610f27 | |||
| a35b28dbef | |||
| 1a4288c811 | |||
| 9dc32f2318 | |||
| 7210f856c9 | |||
| ebcc1200a3 | |||
| e660d7af38 | |||
| d9ccfcbc6e | |||
| a9bcec013f | |||
| aeb7687e2c | |||
| 9355d36718 | |||
| a03ee828a3 | |||
| 7066372892 | |||
| 55f95dbc36 | |||
| 8b40de3c4e | |||
| af4b9bfa8f | |||
| b9e3130388 | |||
| 12d33652b6 | |||
| fe8cf2aff4 | |||
| d1d190374d | |||
| e1be4e6aa8 | |||
| 301a470e7a | |||
| 91251ad5a5 |
@ -12,7 +12,7 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
|
||||
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
|
||||
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
|
||||
- Use Tailwind CSS v4.1+ rules via the `tailwind-css-rules` skill. Prefer v4 utilities, `gap`, `text-size/line-height`, `min-h-dvh`, and avoid deprecated utilities and `@apply`.
|
||||
- Follow Dify's CSS-first Tailwind v4 contract from `packages/dify-ui/README.md` and `packages/dify-ui/AGENTS.md`. Prefer design-system tokens, utilities, and radius mappings over generic Tailwind guidance.
|
||||
|
||||
## Ownership
|
||||
|
||||
|
||||
@ -1,367 +0,0 @@
|
||||
---
|
||||
name: tailwind-css-rules
|
||||
description: Tailwind CSS v4.1+ rules and best practices. Use when writing, reviewing, refactoring, or upgrading Tailwind CSS classes and styles, especially v4 utility migrations, layout spacing, typography, responsive variants, dark mode, gradients, CSS variables, and component styling.
|
||||
---
|
||||
|
||||
# Tailwind CSS Rules and Best Practices
|
||||
|
||||
## Core Principles
|
||||
|
||||
- **Always use Tailwind CSS v4.1+** - Ensure the codebase is using the latest version
|
||||
- **Do not use deprecated or removed utilities** - ALWAYS use the replacement
|
||||
- **Never use `@apply`** - Use CSS variables, the `--spacing()` function, or framework components instead
|
||||
- **Check for redundant classes** - Remove any classes that aren't necessary
|
||||
- **Group elements logically** to simplify responsive tweaks later
|
||||
|
||||
## Upgrading to Tailwind CSS v4
|
||||
|
||||
### Before Upgrading
|
||||
|
||||
- **Always read the upgrade documentation first** - Read https://tailwindcss.com/docs/upgrade-guide and https://tailwindcss.com/blog/tailwindcss-v4 before starting an upgrade.
|
||||
- Ensure the git repository is in a clean state before starting
|
||||
|
||||
### Upgrade Process
|
||||
|
||||
1. Run the upgrade command: `npx @tailwindcss/upgrade@latest` for both major and minor updates
|
||||
2. The tool will convert JavaScript config files to the new CSS format
|
||||
3. Review all changes extensively to clean up any false positives
|
||||
4. Test thoroughly across your application
|
||||
|
||||
## Breaking Changes Reference
|
||||
|
||||
### Removed Utilities (NEVER use these in v4)
|
||||
|
||||
| ❌ Deprecated | ✅ Replacement |
|
||||
| ----------------------- | ------------------------------------------------- |
|
||||
| `bg-opacity-*` | Use opacity modifiers like `bg-black/50` |
|
||||
| `text-opacity-*` | Use opacity modifiers like `text-black/50` |
|
||||
| `border-opacity-*` | Use opacity modifiers like `border-black/50` |
|
||||
| `divide-opacity-*` | Use opacity modifiers like `divide-black/50` |
|
||||
| `ring-opacity-*` | Use opacity modifiers like `ring-black/50` |
|
||||
| `placeholder-opacity-*` | Use opacity modifiers like `placeholder-black/50` |
|
||||
| `flex-shrink-*` | `shrink-*` |
|
||||
| `flex-grow-*` | `grow-*` |
|
||||
| `overflow-ellipsis` | `text-ellipsis` |
|
||||
| `decoration-slice` | `box-decoration-slice` |
|
||||
| `decoration-clone` | `box-decoration-clone` |
|
||||
|
||||
### Renamed Utilities
|
||||
|
||||
Use the v4 name when migrating code that still carries Tailwind v3 semantics. Do not blanket-replace existing v4 classes: classes such as `rounded-sm`, `shadow-sm`, `ring-1`, and `ring-2` are valid in this codebase when they intentionally represent the current design scale.
|
||||
|
||||
| ❌ v3 pattern | ✅ v4 pattern |
|
||||
| ------------------- | -------------------------------------------------- |
|
||||
| `bg-gradient-*` | `bg-linear-*` |
|
||||
| old shadow scale | verify against the current Tailwind/design scale |
|
||||
| old blur scale | verify against the current Tailwind/design scale |
|
||||
| old radius scale | use the Dify radius token mapping when applicable |
|
||||
| `outline-none` | `outline-hidden` |
|
||||
| bare `ring` utility | use an explicit ring width such as `ring-1`/`ring-2`/`ring-3` |
|
||||
|
||||
For Figma radius tokens, follow `packages/dify-ui/AGENTS.md`. For example, `--radius/xs` maps to `rounded-sm`; do not rewrite it to `rounded-xs`.
|
||||
|
||||
## Layout and Spacing Rules
|
||||
|
||||
### Flexbox and Grid Spacing
|
||||
|
||||
#### Always use gap utilities for internal spacing
|
||||
|
||||
Gap provides consistent spacing without edge cases (no extra space on last items). It's cleaner and more maintainable than margins on children.
|
||||
|
||||
```html
|
||||
<!-- ❌ Don't do this -->
|
||||
<div class="flex">
|
||||
<div class="mr-4">Item 1</div>
|
||||
<div class="mr-4">Item 2</div>
|
||||
<div>Item 3</div>
|
||||
<!-- No margin on last -->
|
||||
</div>
|
||||
|
||||
<!-- ✅ Do this instead -->
|
||||
<div class="flex gap-4">
|
||||
<div>Item 1</div>
|
||||
<div>Item 2</div>
|
||||
<div>Item 3</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
#### Gap vs Space utilities
|
||||
|
||||
- **Never use `space-x-*` or `space-y-*` in flex/grid layouts** - always use gap
|
||||
- Space utilities add margins to children and have issues with wrapped items
|
||||
- Gap works correctly with flex-wrap and all flex directions
|
||||
|
||||
```html
|
||||
<!-- ❌ Avoid space utilities in flex containers -->
|
||||
<div class="flex flex-wrap space-x-4">
|
||||
<!-- Space utilities break with wrapped items -->
|
||||
</div>
|
||||
|
||||
<!-- ✅ Use gap for consistent spacing -->
|
||||
<div class="flex flex-wrap gap-4">
|
||||
<!-- Gap works perfectly with wrapping -->
|
||||
</div>
|
||||
```
|
||||
|
||||
### General Spacing Guidelines
|
||||
|
||||
- **Prefer top and left margins** over bottom and right margins (unless conditionally rendered)
|
||||
- **Use padding on parent containers** instead of bottom margins on the last child
|
||||
- **Always use `min-h-dvh` instead of `min-h-screen`** - `min-h-screen` is buggy on mobile Safari
|
||||
- **Prefer `size-*` utilities** over separate `w-*` and `h-*` when setting equal dimensions
|
||||
- For max-widths, prefer the container scale (e.g., `max-w-2xs` over `max-w-72`)
|
||||
|
||||
## Typography Rules
|
||||
|
||||
### Line Heights
|
||||
|
||||
- **Never use `leading-*` classes** - Always use line height modifiers with text size
|
||||
- **Always use fixed line heights from the spacing scale** - Don't use named values
|
||||
|
||||
```html
|
||||
<!-- ❌ Don't do this -->
|
||||
<p class="text-base leading-7">Text with separate line height</p>
|
||||
<p class="text-lg leading-relaxed">Text with named line height</p>
|
||||
|
||||
<!-- ✅ Do this instead -->
|
||||
<p class="text-base/7">Text with line height modifier</p>
|
||||
<p class="text-lg/8">Text with specific line height</p>
|
||||
```
|
||||
|
||||
### Font Size Reference
|
||||
|
||||
Be precise with font sizes - know the actual pixel values:
|
||||
|
||||
- `text-xs` = 12px
|
||||
- `text-sm` = 14px
|
||||
- `text-base` = 16px
|
||||
- `text-lg` = 18px
|
||||
- `text-xl` = 20px
|
||||
|
||||
## Color and Opacity
|
||||
|
||||
### Opacity Modifiers
|
||||
|
||||
**Never use `bg-opacity-*`, `text-opacity-*`, etc.** - use the opacity modifier syntax:
|
||||
|
||||
```html
|
||||
<!-- ❌ Don't do this -->
|
||||
<div class="bg-red-500 bg-opacity-60">Old opacity syntax</div>
|
||||
|
||||
<!-- ✅ Do this instead -->
|
||||
<div class="bg-red-500/60">Modern opacity syntax</div>
|
||||
```
|
||||
|
||||
## Responsive Design
|
||||
|
||||
### Breakpoint Optimization
|
||||
|
||||
- **Check for redundant classes across breakpoints**
|
||||
- **Only add breakpoint variants when values change**
|
||||
|
||||
```html
|
||||
<!-- ❌ Redundant breakpoint classes -->
|
||||
<div class="px-4 md:px-4 lg:px-4">
|
||||
<!-- md:px-4 and lg:px-4 are redundant -->
|
||||
</div>
|
||||
|
||||
<!-- ✅ Efficient breakpoint usage -->
|
||||
<div class="px-4 lg:px-8">
|
||||
<!-- Only specify when value changes -->
|
||||
</div>
|
||||
```
|
||||
|
||||
## Dark Mode
|
||||
|
||||
### Dark Mode Best Practices
|
||||
|
||||
- Use the plain `dark:` variant pattern
|
||||
- Put light mode styles first, then dark mode styles
|
||||
- Ensure `dark:` variant comes before other variants
|
||||
|
||||
```html
|
||||
<!-- ✅ Correct dark mode pattern -->
|
||||
<div class="bg-white text-black dark:bg-black dark:text-white">
|
||||
<button class="hover:bg-gray-100 dark:hover:bg-gray-800">Click me</button>
|
||||
</div>
|
||||
```
|
||||
|
||||
## Gradient Utilities
|
||||
|
||||
- **ALWAYS Use `bg-linear-*` instead of `bg-gradient-*` utilities** - The gradient utilities were renamed in v4
|
||||
- Use the new `bg-radial` or `bg-radial-[<position>]` to create radial gradients
|
||||
- Use the new `bg-conic` or `bg-conic-*` to create conic gradients
|
||||
|
||||
```html
|
||||
<!-- ✅ Use the new gradient utilities -->
|
||||
<div class="h-14 bg-linear-to-br from-violet-500 to-fuchsia-500"></div>
|
||||
<div
|
||||
class="size-18 bg-radial-[at_50%_75%] from-sky-200 via-blue-400 to-indigo-900 to-90%"
|
||||
></div>
|
||||
<div
|
||||
class="size-24 bg-conic-180 from-indigo-600 via-indigo-50 to-indigo-600"
|
||||
></div>
|
||||
|
||||
<!-- ❌ Do not use bg-gradient-* utilities -->
|
||||
<div class="h-14 bg-gradient-to-br from-violet-500 to-fuchsia-500"></div>
|
||||
```
|
||||
|
||||
## Working with CSS Variables
|
||||
|
||||
### Accessing Theme Values
|
||||
|
||||
Tailwind CSS v4 exposes all theme values as CSS variables:
|
||||
|
||||
```css
|
||||
/* Access colors, and other theme values */
|
||||
.custom-element {
|
||||
background: var(--color-red-500);
|
||||
border-radius: var(--radius-lg);
|
||||
}
|
||||
```
|
||||
|
||||
### The `--spacing()` Function
|
||||
|
||||
Use the dedicated `--spacing()` function for spacing calculations:
|
||||
|
||||
```css
|
||||
.custom-class {
|
||||
margin-top: calc(100vh - --spacing(16));
|
||||
}
|
||||
```
|
||||
|
||||
### Extending theme values
|
||||
|
||||
Use CSS to extend theme values:
|
||||
|
||||
```css
|
||||
@import "tailwindcss";
|
||||
|
||||
@theme {
|
||||
--color-mint-500: oklch(0.72 0.11 178);
|
||||
}
|
||||
```
|
||||
|
||||
```html
|
||||
<div class="bg-mint-500">
|
||||
<!-- ... -->
|
||||
</div>
|
||||
```
|
||||
|
||||
## New v4 Features
|
||||
|
||||
### Container Queries
|
||||
|
||||
Use the `@container` class and size variants:
|
||||
|
||||
```html
|
||||
<article class="@container">
|
||||
<div class="flex flex-col @md:flex-row @lg:gap-8">
|
||||
<img class="w-full @md:w-48" />
|
||||
<div class="mt-4 @md:mt-0">
|
||||
<!-- Content adapts to container size -->
|
||||
</div>
|
||||
</div>
|
||||
</article>
|
||||
```
|
||||
|
||||
### Container Query Units
|
||||
|
||||
Use container-based units like `cqw` for responsive sizing:
|
||||
|
||||
```html
|
||||
<div class="@container">
|
||||
<h1 class="text-[50cqw]">Responsive to container width</h1>
|
||||
</div>
|
||||
```
|
||||
|
||||
### Text Shadows (v4.1)
|
||||
|
||||
Use text-shadow-\* utilities from text-shadow-2xs to text-shadow-lg:
|
||||
|
||||
```html
|
||||
<!-- ✅ Text shadow examples -->
|
||||
<h1 class="text-shadow-lg">Large shadow</h1>
|
||||
<p class="text-shadow-sm/50">Small shadow with opacity</p>
|
||||
```
|
||||
|
||||
### Masking (v4.1)
|
||||
|
||||
Use the new composable mask utilities for image and gradient masks:
|
||||
|
||||
```html
|
||||
<!-- ✅ Linear gradient masks on specific sides -->
|
||||
<div class="mask-t-from-50%">Top fade</div>
|
||||
<div class="mask-b-from-20% mask-b-to-80%">Bottom gradient</div>
|
||||
<div class="mask-linear-from-white mask-linear-to-black/60">
|
||||
Fade from white to black
|
||||
</div>
|
||||
|
||||
<!-- ✅ Radial gradient masks -->
|
||||
<div class="mask-radial-[100%_100%] mask-radial-from-75% mask-radial-at-left">
|
||||
Radial mask
|
||||
</div>
|
||||
```
|
||||
|
||||
## Component Patterns
|
||||
|
||||
### Avoiding Utility Inheritance
|
||||
|
||||
Don't add utilities to parents that you override in children:
|
||||
|
||||
```html
|
||||
<!-- ❌ Avoid this pattern -->
|
||||
<div class="text-center">
|
||||
<h1>Centered Heading</h1>
|
||||
<div class="text-left">Left-aligned content</div>
|
||||
</div>
|
||||
|
||||
<!-- ✅ Better approach -->
|
||||
<div>
|
||||
<h1 class="text-center">Centered Heading</h1>
|
||||
<div>Left-aligned content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### Component Extraction
|
||||
|
||||
- Extract repeated patterns into framework components, not CSS classes
|
||||
- Keep utility classes in templates/JSX
|
||||
- Use data attributes for complex state-based styling
|
||||
|
||||
## CSS Best Practices
|
||||
|
||||
### Nesting Guidelines
|
||||
|
||||
- Use nesting when styling both parent and children
|
||||
- Avoid empty parent selectors
|
||||
|
||||
```css
|
||||
/* ✅ Good nesting - parent has styles */
|
||||
.card {
|
||||
padding: --spacing(4);
|
||||
|
||||
> .card-title {
|
||||
font-weight: bold;
|
||||
}
|
||||
}
|
||||
|
||||
/* ❌ Avoid empty parents */
|
||||
ul {
|
||||
> li {
|
||||
/* Parent has no styles */
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Common Pitfalls to Avoid
|
||||
|
||||
1. **Using old opacity utilities** - Always use `/opacity` syntax like `bg-red-500/60`
|
||||
2. **Redundant breakpoint classes** - Only specify changes
|
||||
3. **Space utilities in flex/grid** - Always use gap
|
||||
4. **Leading utilities** - Use line-height modifiers like `text-sm/6`
|
||||
5. **Arbitrary values** - Use the design scale
|
||||
6. **@apply directive** - Use components or CSS variables
|
||||
7. **min-h-screen on mobile** - Use min-h-dvh
|
||||
8. **Separate width/height** - Use size utilities when equal
|
||||
9. **Arbitrary values** - Always use Tailwind's predefined scale whenever possible (e.g., use `ml-4` over `ml-[16px]`)
|
||||
@ -1,5 +1,6 @@
|
||||
[run]
|
||||
omit =
|
||||
api/conftest.py
|
||||
api/tests/*
|
||||
api/migrations/*
|
||||
api/core/rag/datasource/vdb/*
|
||||
|
||||
60
.github/CODEOWNERS
vendored
60
.github/CODEOWNERS
vendored
@ -4,7 +4,7 @@
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
* @crazywoola @laipz8200
|
||||
|
||||
# ESLint suppression file is maintained by autofix.ci pruning.
|
||||
/eslint-suppressions.json
|
||||
@ -85,39 +85,39 @@
|
||||
/api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
/api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
/api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
/api/core/plugin/ @WH-2099
|
||||
/api/services/plugin/ @WH-2099
|
||||
/api/controllers/console/workspace/plugin.py @WH-2099
|
||||
/api/controllers/inner_api/plugin/ @WH-2099
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @WH-2099
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
/api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
/api/core/trigger/ @Mairuis @Yeuoly
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
/api/services/trigger/ @Mairuis @Yeuoly
|
||||
/api/models/trigger.py @Mairuis @Yeuoly
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
/api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
/api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
/api/controllers/trigger/ @Mairuis
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis
|
||||
/api/core/trigger/ @Mairuis
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis
|
||||
/api/services/trigger/ @Mairuis
|
||||
/api/models/trigger.py @Mairuis
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis
|
||||
/api/libs/schedule_utils.py @Mairuis
|
||||
/api/services/workflow/scheduler.py @Mairuis
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis
|
||||
|
||||
# Backend - Async Workflow
|
||||
/api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
/api/services/async_workflow_service.py @Mairuis
|
||||
/api/tasks/async_workflow_tasks.py @Mairuis
|
||||
|
||||
# Backend - Billing
|
||||
/api/services/billing_service.py @hj24 @zyssyz123
|
||||
|
||||
5
.github/actions/setup-web/action.yml
vendored
5
.github/actions/setup-web/action.yml
vendored
@ -1,8 +1,13 @@
|
||||
name: Setup Web Environment
|
||||
description: Set up Node.js, Vite+, pnpm, and web dependencies
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5
|
||||
with:
|
||||
run_install: false
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
|
||||
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
73
.github/scripts/check-hotfix-cherry-picks.sh
vendored
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
BASE_SHA=${BASE_SHA:-}
|
||||
HEAD_SHA=${HEAD_SHA:-}
|
||||
MAIN_REF=${MAIN_REF:-origin/main}
|
||||
REMEDIATION_HINT="Changes should be made from the main branch using git cherry-pick -x."
|
||||
|
||||
error() {
|
||||
printf 'ERROR: %s\n' "$1" >&2
|
||||
}
|
||||
|
||||
if [[ -z "$BASE_SHA" || -z "$HEAD_SHA" ]]; then
|
||||
error "BASE_SHA and HEAD_SHA are required. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$BASE_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Base commit '$BASE_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$HEAD_SHA^{commit}" > /dev/null 2>&1; then
|
||||
error "Head commit '$HEAD_SHA' is not available in the local git checkout."
|
||||
exit 2
|
||||
fi
|
||||
|
||||
if ! git rev-parse --verify "$MAIN_REF^{commit}" > /dev/null 2>&1; then
|
||||
error "Main ref '$MAIN_REF' is not available in the local git checkout. $REMEDIATION_HINT"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
failed=0
|
||||
checked=0
|
||||
|
||||
while IFS= read -r commit_sha; do
|
||||
[[ -n "$commit_sha" ]] || continue
|
||||
|
||||
checked=$((checked + 1))
|
||||
subject=$(git log -1 --format=%s "$commit_sha")
|
||||
source_sha=$(
|
||||
git log -1 --format=%B "$commit_sha" \
|
||||
| sed -nE 's/^\(cherry picked from commit ([0-9a-fA-F]{7,64})\)$/\1/p' \
|
||||
| tail -n 1
|
||||
)
|
||||
|
||||
if [[ -z "$source_sha" ]]; then
|
||||
error "Commit $commit_sha ($subject) is missing cherry-pick provenance. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git cat-file -e "$source_sha^{commit}" 2> /dev/null; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that commit is not available locally. $REMEDIATION_HINT"
|
||||
failed=1
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! git merge-base --is-ancestor "$source_sha" "$MAIN_REF"; then
|
||||
error "Commit $commit_sha ($subject) references source $source_sha, but that source is not reachable from main ($MAIN_REF). $REMEDIATION_HINT"
|
||||
failed=1
|
||||
fi
|
||||
done < <(git rev-list --reverse "$BASE_SHA..$HEAD_SHA")
|
||||
|
||||
if [[ "$failed" -ne 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$checked" -eq 0 ]]; then
|
||||
echo "No PR commits to check."
|
||||
else
|
||||
echo "Verified $checked PR commit(s) include cherry-pick provenance from main."
|
||||
fi
|
||||
709
.github/scripts/reset-test-env.sh
vendored
709
.github/scripts/reset-test-env.sh
vendored
@ -1,709 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -Eeuo pipefail
|
||||
|
||||
SCRIPT_NAME="$(basename "$0")"
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
|
||||
DEFAULT_REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd -P)"
|
||||
|
||||
REPO_ROOT="${DIFY_REPO_ROOT:-$DEFAULT_REPO_ROOT}"
|
||||
YES=false
|
||||
DRY_RUN=true
|
||||
SKIP_SMOKE=false
|
||||
SKIP_MIGRATION=false
|
||||
TIMEOUT_SECONDS="${DIFY_RESET_TIMEOUT_SECONDS:-300}"
|
||||
SMOKE_URL="${DIFY_RESET_SMOKE_URL:-}"
|
||||
LOCK_DIR=""
|
||||
CURRENT_PHASE="init"
|
||||
|
||||
DELETED_PATHS=()
|
||||
SKIPPED_PATHS=()
|
||||
DELETED_NAMED_VOLUMES=()
|
||||
SKIPPED_NAMED_VOLUMES=()
|
||||
PRESERVED_PATHS=()
|
||||
HEALTH_RESULTS=()
|
||||
SMOKE_RESULT="not-run"
|
||||
START_TIME="$(date +%s)"
|
||||
|
||||
RUNTIME_PATHS=(
|
||||
"volumes/db/data"
|
||||
"volumes/mysql/data"
|
||||
"volumes/redis/data"
|
||||
"volumes/app/storage"
|
||||
"volumes/plugin_daemon"
|
||||
"volumes/weaviate"
|
||||
"volumes/qdrant"
|
||||
"volumes/pgvector"
|
||||
"volumes/pgvecto_rs"
|
||||
"volumes/chroma"
|
||||
"volumes/milvus"
|
||||
"volumes/opensearch/data"
|
||||
)
|
||||
|
||||
NAMED_VOLUMES=(
|
||||
"dify_es01_data"
|
||||
)
|
||||
|
||||
PRESERVE_PATHS=(
|
||||
".env"
|
||||
"middleware.env"
|
||||
"docker-compose.yaml"
|
||||
"docker-compose.middleware.yaml"
|
||||
"nginx"
|
||||
"ssrf_proxy"
|
||||
"volumes/certbot"
|
||||
"volumes/opensearch/opensearch_dashboards.yml"
|
||||
"nginx/ssl"
|
||||
)
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: $SCRIPT_NAME [options]
|
||||
|
||||
Safely reset a Dify test environment in place. The command defaults to dry-run.
|
||||
|
||||
Options:
|
||||
--yes Perform destructive reset. Required to delete data.
|
||||
--dry-run Print planned actions without changing services or data.
|
||||
--repo-root PATH Repository root. Defaults to auto-detected Dify root.
|
||||
--smoke-url URL Public URL to verify after restart.
|
||||
--skip-smoke Skip public-domain smoke verification.
|
||||
--skip-migration Skip explicit migration gate.
|
||||
--timeout SECONDS Health-check timeout. Default: $TIMEOUT_SECONDS.
|
||||
-h, --help Show this help.
|
||||
|
||||
Required for destructive reset:
|
||||
ALLOW_DIFY_TEST_RESET=true
|
||||
DIFY_ENV_NAME=test
|
||||
|
||||
Optional:
|
||||
DIFY_RESET_SMOKE_URL=https://test.example.com
|
||||
RESET_TARGET_DOMAIN=test.example.com
|
||||
EOF
|
||||
}
|
||||
|
||||
log() {
|
||||
printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"
|
||||
}
|
||||
|
||||
fail() {
|
||||
local message="$1"
|
||||
print_report "failure"
|
||||
printf 'ERROR: %s\n' "$message" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
run_cmd() {
|
||||
printf '+'
|
||||
printf ' %q' "$@"
|
||||
printf '\n'
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
set +e
|
||||
"$@"
|
||||
local status=$?
|
||||
set -e
|
||||
if [ "$status" -ne 0 ]; then
|
||||
fail "Command failed with exit code $status: $(command_string "$@")"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
command_string() {
|
||||
local arg
|
||||
local result=""
|
||||
for arg in "$@"; do
|
||||
result="$result $(printf '%q' "$arg")"
|
||||
done
|
||||
printf '%s' "${result# }"
|
||||
}
|
||||
|
||||
read_env_value() {
|
||||
local key="$1"
|
||||
local default_value="$2"
|
||||
local env_file="$DOCKER_DIR/.env"
|
||||
local value=""
|
||||
|
||||
if [ -f "$env_file" ]; then
|
||||
value="$(awk -F= -v key="$key" '
|
||||
$0 !~ /^[[:space:]]*#/ && $1 == key {
|
||||
sub(/^[^=]*=/, "")
|
||||
print
|
||||
}
|
||||
' "$env_file" | tail -n 1)"
|
||||
fi
|
||||
|
||||
if [ -z "$value" ]; then
|
||||
printf '%s' "$default_value"
|
||||
return
|
||||
fi
|
||||
|
||||
value="${value%\"}"
|
||||
value="${value#\"}"
|
||||
value="${value%\'}"
|
||||
value="${value#\'}"
|
||||
printf '%s' "$value"
|
||||
}
|
||||
|
||||
parse_args() {
|
||||
while [ "$#" -gt 0 ]; do
|
||||
case "$1" in
|
||||
--yes)
|
||||
YES=true
|
||||
DRY_RUN=false
|
||||
;;
|
||||
--dry-run)
|
||||
DRY_RUN=true
|
||||
;;
|
||||
--repo-root)
|
||||
[ "$#" -ge 2 ] || fail "--repo-root requires a path"
|
||||
REPO_ROOT="$2"
|
||||
shift
|
||||
;;
|
||||
--smoke-url)
|
||||
[ "$#" -ge 2 ] || fail "--smoke-url requires a URL"
|
||||
SMOKE_URL="$2"
|
||||
shift
|
||||
;;
|
||||
--skip-smoke)
|
||||
SKIP_SMOKE=true
|
||||
;;
|
||||
--skip-migration)
|
||||
SKIP_MIGRATION=true
|
||||
;;
|
||||
--timeout)
|
||||
[ "$#" -ge 2 ] || fail "--timeout requires seconds"
|
||||
TIMEOUT_SECONDS="$2"
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
fail "Unknown option: $1"
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
}
|
||||
|
||||
validate_number() {
|
||||
case "$TIMEOUT_SECONDS" in
|
||||
''|*[!0-9]*)
|
||||
fail "--timeout must be a positive integer"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
require_docker() {
|
||||
command -v docker >/dev/null 2>&1 || fail "docker command not found"
|
||||
docker compose version >/dev/null 2>&1 || fail "docker compose is not available"
|
||||
}
|
||||
|
||||
validate_environment() {
|
||||
CURRENT_PHASE="validate"
|
||||
REPO_ROOT="$(cd "$REPO_ROOT" && pwd -P)"
|
||||
DOCKER_DIR="$REPO_ROOT/docker"
|
||||
|
||||
[ -d "$DOCKER_DIR" ] || fail "Docker directory not found: $DOCKER_DIR"
|
||||
[ -f "$DOCKER_DIR/docker-compose.yaml" ] || fail "docker-compose.yaml not found in $DOCKER_DIR"
|
||||
[ -f "$DOCKER_DIR/.env" ] || fail ".env not found in $DOCKER_DIR"
|
||||
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
[ "$YES" = true ] || fail "Destructive reset requires --yes"
|
||||
[ "${ALLOW_DIFY_TEST_RESET:-}" = "true" ] || fail "ALLOW_DIFY_TEST_RESET=true is required"
|
||||
[ "${DIFY_ENV_NAME:-}" = "test" ] || fail "DIFY_ENV_NAME=test is required"
|
||||
require_docker
|
||||
fi
|
||||
}
|
||||
|
||||
acquire_lock() {
|
||||
CURRENT_PHASE="lock"
|
||||
local env_name="${DIFY_ENV_NAME:-dry-run}"
|
||||
LOCK_DIR="${TMPDIR:-/tmp}/dify-test-reset-${env_name}.lock"
|
||||
|
||||
if ! mkdir "$LOCK_DIR" 2>/dev/null; then
|
||||
fail "Reset lock is already held: $LOCK_DIR"
|
||||
fi
|
||||
|
||||
printf '%s\n' "$$" > "$LOCK_DIR/pid"
|
||||
trap cleanup EXIT
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
if [ -n "$LOCK_DIR" ] && [ -d "$LOCK_DIR" ]; then
|
||||
rm -rf "$LOCK_DIR"
|
||||
fi
|
||||
}
|
||||
|
||||
compose() {
|
||||
local args=(compose --env-file "$DOCKER_DIR/.env" -f "$DOCKER_DIR/docker-compose.yaml")
|
||||
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
|
||||
args+=(-p "$DIFY_COMPOSE_PROJECT")
|
||||
fi
|
||||
|
||||
docker "${args[@]}" "$@"
|
||||
}
|
||||
|
||||
compose_project_name() {
|
||||
if [ -n "${DIFY_COMPOSE_PROJECT:-}" ]; then
|
||||
printf '%s' "$DIFY_COMPOSE_PROJECT"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "${COMPOSE_PROJECT_NAME:-}" ]; then
|
||||
printf '%s' "$COMPOSE_PROJECT_NAME"
|
||||
return
|
||||
fi
|
||||
|
||||
local env_project
|
||||
env_project="$(read_env_value COMPOSE_PROJECT_NAME "")"
|
||||
if [ -n "$env_project" ]; then
|
||||
printf '%s' "$env_project"
|
||||
return
|
||||
fi
|
||||
|
||||
basename "$DOCKER_DIR"
|
||||
}
|
||||
|
||||
active_db_service() {
|
||||
local db_type
|
||||
db_type="$(read_env_value DB_TYPE postgresql)"
|
||||
case "$db_type" in
|
||||
postgresql|'')
|
||||
printf '%s\n' "db_postgres"
|
||||
;;
|
||||
mysql)
|
||||
printf '%s\n' "db_mysql"
|
||||
;;
|
||||
oceanbase)
|
||||
printf '%s\n' "oceanbase"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "$db_type"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
active_vector_service() {
|
||||
local vector_store
|
||||
vector_store="$(read_env_value VECTOR_STORE weaviate)"
|
||||
case "$vector_store" in
|
||||
''|none|external)
|
||||
return 0
|
||||
;;
|
||||
pgvecto-rs|pgvecto_rs)
|
||||
printf '%s\n' "pgvecto-rs"
|
||||
;;
|
||||
milvus)
|
||||
printf '%s\n' "milvus-standalone"
|
||||
;;
|
||||
elasticsearch|opensearch|weaviate|qdrant|pgvector|chroma|oceanbase|seekdb|couchbase-server|iris)
|
||||
printf '%s\n' "$vector_store"
|
||||
;;
|
||||
couchbase)
|
||||
printf '%s\n' "couchbase-server"
|
||||
;;
|
||||
*)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
safe_runtime_path() {
|
||||
local rel_path="$1"
|
||||
case "$rel_path" in
|
||||
""|"/"| "." | ".." | *".."* | /*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
case "$rel_path" in
|
||||
volumes/*)
|
||||
return 0
|
||||
;;
|
||||
*)
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
safe_named_volume() {
|
||||
local volume="$1"
|
||||
case "$volume" in
|
||||
""|*"/"*|*" "*|*$'\t'*|*$'\n'*|*$'\r'*)
|
||||
return 1
|
||||
;;
|
||||
*[!a-zA-Z0-9_.-]*)
|
||||
return 1
|
||||
;;
|
||||
*)
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
volume_exists() {
|
||||
docker volume inspect "$1" >/dev/null 2>&1
|
||||
}
|
||||
|
||||
append_unique_volume() {
|
||||
local candidate="$1"
|
||||
local existing
|
||||
[ -n "$candidate" ] || return 0
|
||||
|
||||
for existing in "${RESOLVED_VOLUME_NAMES[@]}"; do
|
||||
if [ "$existing" = "$candidate" ]; then
|
||||
return
|
||||
fi
|
||||
done
|
||||
|
||||
RESOLVED_VOLUME_NAMES+=("$candidate")
|
||||
}
|
||||
|
||||
resolve_named_volume_names() {
|
||||
local logical_name="$1"
|
||||
local project_name
|
||||
local candidate
|
||||
local volume_list
|
||||
local status
|
||||
RESOLVED_VOLUME_NAMES=()
|
||||
|
||||
project_name="$(compose_project_name)"
|
||||
|
||||
set +e
|
||||
volume_list="$(docker volume ls -q \
|
||||
--filter "label=com.docker.compose.project=$project_name" \
|
||||
--filter "label=com.docker.compose.volume=$logical_name" 2>/dev/null)"
|
||||
status=$?
|
||||
set -e
|
||||
|
||||
if [ "$status" -ne 0 ]; then
|
||||
fail "Failed to list Docker volumes for Compose project $project_name"
|
||||
fi
|
||||
|
||||
while IFS= read -r candidate; do
|
||||
append_unique_volume "$candidate"
|
||||
done <<< "$volume_list"
|
||||
|
||||
for candidate in "${project_name}_${logical_name}" "$logical_name"; do
|
||||
if volume_exists "$candidate"; then
|
||||
append_unique_volume "$candidate"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
collect_preserved_paths() {
|
||||
PRESERVED_PATHS=()
|
||||
local rel_path
|
||||
for rel_path in "${PRESERVE_PATHS[@]}"; do
|
||||
if [ -e "$DOCKER_DIR/$rel_path" ]; then
|
||||
PRESERVED_PATHS+=("$rel_path")
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
print_plan() {
|
||||
CURRENT_PHASE="plan"
|
||||
local db_service
|
||||
local vector_service
|
||||
db_service="$(active_db_service)"
|
||||
vector_service="$(active_vector_service || true)"
|
||||
|
||||
collect_preserved_paths
|
||||
|
||||
log "Reset mode: $([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
|
||||
log "Repository root: $REPO_ROOT"
|
||||
log "Docker directory: $DOCKER_DIR"
|
||||
log "Compose project: $(compose_project_name)"
|
||||
log "Database service: $db_service"
|
||||
log "Vector service: ${vector_service:-<external-or-none>}"
|
||||
log "Timeout: ${TIMEOUT_SECONDS}s"
|
||||
|
||||
printf '\nPlanned runtime path deletions:\n'
|
||||
local rel_path
|
||||
for rel_path in "${RUNTIME_PATHS[@]}"; do
|
||||
printf ' - %s\n' "$rel_path"
|
||||
done
|
||||
|
||||
printf '\nPlanned named volume deletions:\n'
|
||||
local volume
|
||||
for volume in "${NAMED_VOLUMES[@]}"; do
|
||||
printf ' - %s (Compose project: %s)\n' "$volume" "$(compose_project_name)"
|
||||
done
|
||||
|
||||
printf '\nPreserved configuration paths found:\n'
|
||||
for rel_path in "${PRESERVED_PATHS[@]}"; do
|
||||
printf ' - %s\n' "$rel_path"
|
||||
done
|
||||
|
||||
printf '\nCommands:\n'
|
||||
printf ' - docker compose down --remove-orphans\n'
|
||||
printf ' - delete allowlisted runtime paths and named volumes\n'
|
||||
printf ' - docker compose up -d %s redis%s\n' "$db_service" "${vector_service:+ $vector_service}"
|
||||
printf ' - docker compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api\n'
|
||||
printf ' - docker compose up -d\n'
|
||||
printf ' - health checks and smoke check\n\n'
|
||||
}
|
||||
|
||||
delete_runtime_paths() {
|
||||
CURRENT_PHASE="delete-runtime-data"
|
||||
local rel_path
|
||||
local abs_path
|
||||
|
||||
for rel_path in "${RUNTIME_PATHS[@]}"; do
|
||||
safe_runtime_path "$rel_path" || fail "Unsafe runtime path in allowlist: $rel_path"
|
||||
abs_path="$DOCKER_DIR/$rel_path"
|
||||
|
||||
if [ ! -e "$abs_path" ]; then
|
||||
SKIPPED_PATHS+=("$rel_path (absent)")
|
||||
continue
|
||||
fi
|
||||
|
||||
DELETED_PATHS+=("$rel_path")
|
||||
run_cmd rm -rf -- "$abs_path"
|
||||
done
|
||||
}
|
||||
|
||||
delete_named_volumes() {
|
||||
CURRENT_PHASE="delete-runtime-volumes"
|
||||
local logical_name
|
||||
local actual_name
|
||||
|
||||
for logical_name in "${NAMED_VOLUMES[@]}"; do
|
||||
safe_named_volume "$logical_name" || fail "Unsafe named volume in allowlist: $logical_name"
|
||||
resolve_named_volume_names "$logical_name"
|
||||
|
||||
if [ "${#RESOLVED_VOLUME_NAMES[@]}" -eq 0 ]; then
|
||||
SKIPPED_NAMED_VOLUMES+=("$logical_name (absent)")
|
||||
continue
|
||||
fi
|
||||
|
||||
for actual_name in "${RESOLVED_VOLUME_NAMES[@]}"; do
|
||||
DELETED_NAMED_VOLUMES+=("$actual_name")
|
||||
run_cmd docker volume rm "$actual_name"
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
stop_stack() {
|
||||
CURRENT_PHASE="stop-stack"
|
||||
run_cmd compose down --remove-orphans
|
||||
}
|
||||
|
||||
start_middleware() {
|
||||
CURRENT_PHASE="start-middleware"
|
||||
local db_service
|
||||
local vector_service
|
||||
local services=()
|
||||
|
||||
db_service="$(active_db_service)"
|
||||
vector_service="$(active_vector_service || true)"
|
||||
services+=("$db_service" "redis")
|
||||
|
||||
if [ -n "$vector_service" ]; then
|
||||
services+=("$vector_service")
|
||||
fi
|
||||
|
||||
run_cmd compose up -d "${services[@]}"
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
wait_for_services "${services[@]}"
|
||||
fi
|
||||
}
|
||||
|
||||
run_migration() {
|
||||
CURRENT_PHASE="migration"
|
||||
if [ "$SKIP_MIGRATION" = true ]; then
|
||||
HEALTH_RESULTS+=("migration:skipped")
|
||||
return
|
||||
fi
|
||||
|
||||
run_cmd compose run --rm -e MIGRATION_ENABLED=true -e MODE=migration api
|
||||
HEALTH_RESULTS+=("migration:ok")
|
||||
}
|
||||
|
||||
start_full_stack() {
|
||||
CURRENT_PHASE="start-full-stack"
|
||||
run_cmd compose up -d
|
||||
|
||||
if [ "$DRY_RUN" = false ]; then
|
||||
wait_for_services api web worker nginx
|
||||
wait_if_service_exists plugin_daemon
|
||||
fi
|
||||
}
|
||||
|
||||
container_status() {
|
||||
local service="$1"
|
||||
local container_id
|
||||
container_id="$(compose ps -q "$service" 2>/dev/null || true)"
|
||||
[ -n "$container_id" ] || return 1
|
||||
|
||||
local health
|
||||
health="$(docker inspect --format '{{if .State.Health}}{{.State.Health.Status}}{{else}}{{.State.Status}}{{end}}' "$container_id" 2>/dev/null || true)"
|
||||
printf '%s' "$health"
|
||||
}
|
||||
|
||||
wait_if_service_exists() {
|
||||
local service="$1"
|
||||
if [ -n "$(compose ps -q "$service" 2>/dev/null || true)" ]; then
|
||||
wait_for_services "$service"
|
||||
fi
|
||||
}
|
||||
|
||||
wait_for_services() {
|
||||
local service
|
||||
for service in "$@"; do
|
||||
wait_for_service "$service"
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local service="$1"
|
||||
local deadline=$(( $(date +%s) + TIMEOUT_SECONDS ))
|
||||
local status=""
|
||||
|
||||
log "Waiting for service: $service"
|
||||
while [ "$(date +%s)" -le "$deadline" ]; do
|
||||
status="$(container_status "$service" || true)"
|
||||
case "$status" in
|
||||
healthy|running)
|
||||
HEALTH_RESULTS+=("$service:$status")
|
||||
return 0
|
||||
;;
|
||||
unhealthy|exited|dead)
|
||||
HEALTH_RESULTS+=("$service:$status")
|
||||
fail "Service $service reached failure status: $status"
|
||||
;;
|
||||
esac
|
||||
sleep 3
|
||||
done
|
||||
|
||||
HEALTH_RESULTS+=("$service:timeout")
|
||||
fail "Timed out waiting for service: $service"
|
||||
}
|
||||
|
||||
default_smoke_url() {
|
||||
if [ -n "$SMOKE_URL" ]; then
|
||||
printf '%s' "$SMOKE_URL"
|
||||
return
|
||||
fi
|
||||
|
||||
if [ -n "${RESET_TARGET_DOMAIN:-}" ]; then
|
||||
local https_enabled
|
||||
https_enabled="$(read_env_value NGINX_HTTPS_ENABLED false)"
|
||||
if [ "$https_enabled" = "true" ]; then
|
||||
printf 'https://%s' "$RESET_TARGET_DOMAIN"
|
||||
else
|
||||
printf 'http://%s' "$RESET_TARGET_DOMAIN"
|
||||
fi
|
||||
return
|
||||
fi
|
||||
|
||||
local port
|
||||
port="$(read_env_value EXPOSE_NGINX_PORT 80)"
|
||||
printf 'http://localhost:%s' "$port"
|
||||
}
|
||||
|
||||
run_smoke_check() {
|
||||
CURRENT_PHASE="smoke"
|
||||
if [ "$SKIP_SMOKE" = true ]; then
|
||||
SMOKE_RESULT="skipped"
|
||||
return
|
||||
fi
|
||||
|
||||
local url
|
||||
url="$(default_smoke_url)"
|
||||
if [ "$DRY_RUN" = true ]; then
|
||||
SMOKE_RESULT="planned:$url"
|
||||
printf '+ curl -fsS --max-time 10 %q\n' "$url"
|
||||
return
|
||||
fi
|
||||
|
||||
curl -fsS --max-time 10 "$url" >/dev/null || fail "Smoke check failed: $url"
|
||||
SMOKE_RESULT="ok:$url"
|
||||
}
|
||||
|
||||
print_report() {
|
||||
local status="${1:-success}"
|
||||
local end_time
|
||||
end_time="$(date +%s)"
|
||||
|
||||
printf '\nReset report\n'
|
||||
printf '============\n'
|
||||
printf 'status: %s\n' "$status"
|
||||
printf 'environment: %s\n' "${DIFY_ENV_NAME:-<unset>}"
|
||||
printf 'repo_root: %s\n' "${REPO_ROOT:-<unset>}"
|
||||
printf 'phase: %s\n' "$CURRENT_PHASE"
|
||||
printf 'duration_seconds: %s\n' "$(( end_time - START_TIME ))"
|
||||
printf 'mode: %s\n' "$([ "$DRY_RUN" = true ] && printf dry-run || printf destructive)"
|
||||
|
||||
printf '\ndeleted_runtime_paths:\n'
|
||||
if [ "${#DELETED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${DELETED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nskipped_runtime_paths:\n'
|
||||
if [ "${#SKIPPED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${SKIPPED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\ndeleted_named_volumes:\n'
|
||||
if [ "${#DELETED_NAMED_VOLUMES[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${DELETED_NAMED_VOLUMES[@]}"
|
||||
fi
|
||||
|
||||
printf '\nskipped_named_volumes:\n'
|
||||
if [ "${#SKIPPED_NAMED_VOLUMES[@]}" -eq 0 ]; then
|
||||
printf ' - <none>\n'
|
||||
else
|
||||
printf ' - %s\n' "${SKIPPED_NAMED_VOLUMES[@]}"
|
||||
fi
|
||||
|
||||
printf '\npreserved_paths:\n'
|
||||
if [ "${#PRESERVED_PATHS[@]}" -eq 0 ]; then
|
||||
printf ' - <none found>\n'
|
||||
else
|
||||
printf ' - %s\n' "${PRESERVED_PATHS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nhealth_results:\n'
|
||||
if [ "${#HEALTH_RESULTS[@]}" -eq 0 ]; then
|
||||
printf ' - <not run>\n'
|
||||
else
|
||||
printf ' - %s\n' "${HEALTH_RESULTS[@]}"
|
||||
fi
|
||||
|
||||
printf '\nsmoke_result: %s\n' "$SMOKE_RESULT"
|
||||
}
|
||||
|
||||
main() {
|
||||
parse_args "$@"
|
||||
validate_number
|
||||
validate_environment
|
||||
acquire_lock
|
||||
print_plan
|
||||
|
||||
if [ "$DRY_RUN" = true ]; then
|
||||
run_smoke_check
|
||||
print_report "dry-run"
|
||||
return 0
|
||||
fi
|
||||
|
||||
stop_stack
|
||||
delete_runtime_paths
|
||||
delete_named_volumes
|
||||
start_middleware
|
||||
run_migration
|
||||
start_full_stack
|
||||
run_smoke_check
|
||||
CURRENT_PHASE="complete"
|
||||
print_report "success"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
42
.github/workflows/api-tests.yml
vendored
42
.github/workflows/api-tests.yml
vendored
@ -48,10 +48,23 @@ jobs:
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
run: uv run --project api pytest api/tests/unit_tests/configs/test_env_consistency.py
|
||||
|
||||
- name: Run Unit Tests
|
||||
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
-n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers
|
||||
# Controller tests register Flask routes at import time, so keep them out of xdist.
|
||||
uv run --project api pytest \
|
||||
--timeout "${PYTEST_TIMEOUT:-20}" \
|
||||
--cov-append \
|
||||
api/tests/unit_tests/controllers
|
||||
|
||||
- name: Upload unit coverage data
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
@ -96,32 +109,11 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
- name: Run Integration Tests
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
-p no:benchmark \
|
||||
--start-middleware \
|
||||
-n auto \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/workflow \
|
||||
|
||||
6
.github/workflows/autofix.yml
vendored
6
.github/workflows/autofix.yml
vendored
@ -120,7 +120,11 @@ jobs:
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd api
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir openapi --markdown-dir openapi/markdown
|
||||
uv run dev/generate_swagger_markdown_docs.py --swagger-dir ../packages/contracts/openapi --markdown-dir openapi/markdown --keep-swagger-json
|
||||
|
||||
- name: Generate frontend contracts
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: pnpm --dir packages/contracts gen-api-contract-from-openapi
|
||||
|
||||
- name: ESLint autofix
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
|
||||
17
.github/workflows/expose_service_ports.sh
vendored
17
.github/workflows/expose_service_ports.sh
vendored
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.pgvector.ports += ["5433:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
49
.github/workflows/hotfix-cherry-pick.yml
vendored
Normal file
@ -0,0 +1,49 @@
|
||||
name: Hotfix Cherry-Pick Provenance
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- 'hotfix/**'
|
||||
- 'lts/**'
|
||||
types:
|
||||
- opened
|
||||
- edited
|
||||
- reopened
|
||||
- ready_for_review
|
||||
- synchronize
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: hotfix-cherry-pick-${{ github.event.pull_request.number || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check-cherry-pick-provenance:
|
||||
name: Require cherry-pick provenance
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Fetch PR base, PR head, and main
|
||||
env:
|
||||
BASE_REF: ${{ github.base_ref }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
run: |
|
||||
git fetch --no-tags --prune origin \
|
||||
"+refs/heads/main:refs/remotes/origin/main" \
|
||||
"+refs/heads/${BASE_REF}:refs/remotes/origin/${BASE_REF}" \
|
||||
"+refs/pull/${PR_NUMBER}/head:refs/remotes/pull/${PR_NUMBER}/head"
|
||||
|
||||
- name: Load checker from main
|
||||
run: git show origin/main:.github/scripts/check-hotfix-cherry-picks.sh > "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
|
||||
- name: Check PR commits
|
||||
env:
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
MAIN_REF: origin/main
|
||||
run: bash "$RUNNER_TEMP/check-hotfix-cherry-picks.sh"
|
||||
6
.github/workflows/main-ci.yml
vendored
6
.github/workflows/main-ci.yml
vendored
@ -55,7 +55,6 @@ jobs:
|
||||
api:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
@ -90,11 +89,13 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/conftest.py'
|
||||
- 'api/tests/pytest_dify.py'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.pytest.ports.yaml'
|
||||
- 'docker/docker-compose.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
- 'docker/generate_docker_compose'
|
||||
@ -114,7 +115,6 @@ jobs:
|
||||
- 'api/migrations/**'
|
||||
- 'api/.env.example'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/envs/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
|
||||
14
.github/workflows/style.yml
vendored
14
.github/workflows/style.yml
vendored
@ -77,6 +77,8 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
e2e/**
|
||||
sdks/nodejs-client/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
@ -94,14 +96,14 @@ jobs:
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
path: .eslintcache
|
||||
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
working-directory: .
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
@ -113,7 +115,7 @@ jobs:
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
working-directory: .
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
@ -125,7 +127,7 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
path: .eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
|
||||
39
.github/workflows/vdb-tests-full.yml
vendored
39
.github/workflows/vdb-tests-full.yml
vendored
@ -48,14 +48,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -64,32 +56,13 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--vdb-services "weaviate,qdrant,couchbase-server,etcd,minio,milvus-standalone,pgvecto-rs,pgvector,chroma,elasticsearch,oceanbase" \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/*/tests/integration_tests
|
||||
|
||||
31
.github/workflows/vdb-tests.yml
vendored
31
.github/workflows/vdb-tests.yml
vendored
@ -45,14 +45,6 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/envs/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
@ -61,31 +53,14 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
pgvector
|
||||
chroma
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
uv run --project api pytest \
|
||||
--start-vdb \
|
||||
--timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -203,6 +203,7 @@ sdks/python-client/dify_client.egg-info
|
||||
|
||||
.vscode/*
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/settings.example.json
|
||||
!.vscode/README.md
|
||||
api/.vscode
|
||||
# vscode Code History Extension
|
||||
@ -249,3 +250,5 @@ scripts/stress-test/reports/
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.context/*
|
||||
.eslintcache
|
||||
|
||||
@ -56,44 +56,9 @@ if $api_modified; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if $web_modified; then
|
||||
if $skip_web_checks; then
|
||||
echo "Git operation in progress, skipping web checks"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Running ESLint on web module"
|
||||
|
||||
if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then
|
||||
web_ts_modified=false
|
||||
else
|
||||
ts_diff_status=$?
|
||||
if [ $ts_diff_status -eq 1 ]; then
|
||||
web_ts_modified=true
|
||||
else
|
||||
echo "Unable to determine staged TypeScript changes (git exit code: $ts_diff_status)."
|
||||
exit $ts_diff_status
|
||||
fi
|
||||
fi
|
||||
|
||||
cd ./web || exit 1
|
||||
pnpm exec vp staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! npm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! npm run knip; then
|
||||
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd ../
|
||||
if $skip_web_checks; then
|
||||
echo "Git operation in progress, skipping web checks"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
vp staged
|
||||
|
||||
@ -9,6 +9,7 @@ The codebase is split into:
|
||||
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
|
||||
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
|
||||
- **Docker deployment** (`/docker`): Containerized deployment configurations
|
||||
- **Dify Agent Backend** (`/dify-agent`): Backend services for managing and executing agent
|
||||
|
||||
## Backend Workflow
|
||||
|
||||
|
||||
61
Makefile
61
Makefile
@ -83,16 +83,15 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "📝 Running core type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude '(^|/)conftest\.py$$' --exclude 'tests/' --exclude 'migrations/' --exclude 'dev/generate_swagger_specs.py' --exclude 'dev/generate_fastopenapi_specs.py' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@ -101,7 +100,46 @@ test:
|
||||
echo "Target: $(TARGET_TESTS)"; \
|
||||
uv run --project api --dev pytest $(TARGET_TESTS); \
|
||||
else \
|
||||
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
|
||||
echo "Running backend unit tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers; \
|
||||
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
|
||||
api/tests/unit_tests/controllers; \
|
||||
fi
|
||||
@echo "✅ Unit tests complete"
|
||||
|
||||
test-all:
|
||||
@echo "🧪 Running full backend test suite..."
|
||||
@if [ -n "$(TARGET_TESTS)" ]; then \
|
||||
echo "Target: $(TARGET_TESTS)"; \
|
||||
uv run --project api --dev pytest $(TARGET_TESTS); \
|
||||
else \
|
||||
echo "Running backend unit tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --timeout "$${PYTEST_TIMEOUT:-20}" -n auto \
|
||||
api/tests/unit_tests \
|
||||
api/providers/vdb/*/tests/unit_tests \
|
||||
api/providers/trace/*/tests/unit_tests \
|
||||
--ignore=api/tests/unit_tests/controllers; \
|
||||
uv run --project api --dev pytest --timeout "$${PYTEST_TIMEOUT:-20}" --cov-append \
|
||||
api/tests/unit_tests/controllers; \
|
||||
echo "Running backend integration tests"; \
|
||||
uv run --project api --dev pytest -p no:benchmark --start-middleware -n auto \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/tests/integration_tests/workflow \
|
||||
api/tests/integration_tests/tools \
|
||||
api/tests/test_containers_integration_tests; \
|
||||
echo "Running VDB smoke tests"; \
|
||||
uv run --project api --dev pytest --start-vdb \
|
||||
--timeout "$${PYTEST_TIMEOUT:-180}" \
|
||||
--cov-append \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests; \
|
||||
fi
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
@ -153,9 +191,10 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo " make test-all - Run full backend tests, including Docker-backed suites"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@ -165,4 +204,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
|
||||
@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
|
||||
@ -180,6 +180,8 @@ Quick checks while iterating:
|
||||
- Format: `make format`
|
||||
- Lint (includes auto-fix): `make lint`
|
||||
- Type check: `make type-check`
|
||||
- Unit tests: `make test`
|
||||
- Full backend tests, including Docker-backed suites: `make test-all`
|
||||
- Targeted tests: `make test TARGET_TESTS=./api/tests/<target_tests>`
|
||||
|
||||
Before opening a PR / submitting:
|
||||
@ -195,7 +197,7 @@ Before opening a PR / submitting:
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize with `ResponseModel.model_validate(...).model_dump(...)`,
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
and avoid adding new legacy `ns.model(...)`, `@marshal_with(...)`, or GET `@ns.expect(...)` patterns.
|
||||
|
||||
### Miscellaneous
|
||||
|
||||
@ -17,14 +17,15 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
git g++ \
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev --group evaluation
|
||||
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
@ -77,7 +78,6 @@ RUN \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
git \
|
||||
nodejs=${NODE_PACKAGE_VERSION} \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
|
||||
@ -99,7 +99,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
uv run pyrefly check # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
@ -117,7 +117,7 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
logger.warning("Failed to add trace headers to response", exc_info=True)
|
||||
return response
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
# Capture the decorator return values so static checkers do not treat the hooks as unused.
|
||||
_ = before_request
|
||||
_ = add_trace_headers
|
||||
|
||||
|
||||
1
api/clients/__init__.py
Normal file
1
api/clients/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""External service client packages."""
|
||||
74
api/clients/agent_backend/__init__.py
Normal file
74
api/clients/agent_backend/__init__.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""API-side integration boundary for the Dify Agent backend.
|
||||
|
||||
Public wire DTOs come from ``dify_agent.protocol``. This package only contains
|
||||
API adapters: request building from Dify product concepts, a thin client wrapper,
|
||||
event adaptation for future workflow integration, and deterministic fakes.
|
||||
"""
|
||||
|
||||
from clients.agent_backend.client import AgentBackendRunClient, DifyAgentBackendRunClient
|
||||
from clients.agent_backend.errors import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
AgentBackendRequestBuildError,
|
||||
AgentBackendRunFailedError,
|
||||
AgentBackendStreamError,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
)
|
||||
from clients.agent_backend.event_adapter import (
|
||||
AgentBackendInternalEvent,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunCancelledInternalEvent,
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendRunFailedInternalEvent,
|
||||
AgentBackendRunPausedInternalEvent,
|
||||
AgentBackendRunStartedInternalEvent,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
AgentBackendStreamInternalEvent,
|
||||
)
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
"AgentBackendInternalEventType",
|
||||
"AgentBackendModelConfig",
|
||||
"AgentBackendOutputConfig",
|
||||
"AgentBackendRequestBuildError",
|
||||
"AgentBackendRunCancelledInternalEvent",
|
||||
"AgentBackendRunClient",
|
||||
"AgentBackendRunEventAdapter",
|
||||
"AgentBackendRunFailedError",
|
||||
"AgentBackendRunFailedInternalEvent",
|
||||
"AgentBackendRunPausedInternalEvent",
|
||||
"AgentBackendRunRequestBuilder",
|
||||
"AgentBackendRunStartedInternalEvent",
|
||||
"AgentBackendRunSucceededInternalEvent",
|
||||
"AgentBackendStreamError",
|
||||
"AgentBackendStreamInternalEvent",
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
130
api/clients/agent_backend/client.py
Normal file
130
api/clients/agent_backend/client.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""Synchronous API-side wrapper around the public ``dify-agent`` client.
|
||||
|
||||
``dify-agent`` owns the cross-service DTOs and HTTP/SSE implementation. The API
|
||||
backend keeps this thin wrapper so workflow code depends on a local protocol,
|
||||
gets API-native errors, and can use a deterministic fake in tests without
|
||||
creating another wire contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Protocol
|
||||
|
||||
from dify_agent.client import (
|
||||
DifyAgentClientError,
|
||||
DifyAgentHTTPError,
|
||||
DifyAgentStreamError,
|
||||
DifyAgentTimeoutError,
|
||||
DifyAgentValidationError,
|
||||
)
|
||||
from dify_agent.protocol import (
|
||||
CancelRunRequest,
|
||||
CancelRunResponse,
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
RunEvent,
|
||||
RunStatusResponse,
|
||||
)
|
||||
|
||||
from clients.agent_backend.errors import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
AgentBackendStreamError,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
)
|
||||
|
||||
|
||||
class AgentBackendRunClient(Protocol):
|
||||
"""Local boundary used by API workflow integrations to run Agent backend jobs."""
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one Agent backend run and return its accepted status."""
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Request explicit cancellation for one Agent backend run."""
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Yield public ``dify-agent`` run events in stream order."""
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Wait for a run to reach a terminal status and return that status."""
|
||||
|
||||
|
||||
class _DifyAgentSyncClient(Protocol):
|
||||
"""Subset of ``dify_agent.client.Client`` used by the API wrapper."""
|
||||
|
||||
def create_run_sync(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one run synchronously."""
|
||||
|
||||
def cancel_run_sync(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Cancel one run synchronously."""
|
||||
|
||||
def stream_events_sync(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Stream run events synchronously."""
|
||||
|
||||
def wait_run_sync(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Wait for terminal run status synchronously."""
|
||||
|
||||
|
||||
class DifyAgentBackendRunClient:
|
||||
"""Adapter from API sync call sites to ``dify_agent.client.Client`` sync methods."""
|
||||
|
||||
client: _DifyAgentSyncClient
|
||||
|
||||
def __init__(self, client: _DifyAgentSyncClient) -> None:
|
||||
self.client = client
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Create one run through ``POST /runs`` and normalize client exceptions."""
|
||||
try:
|
||||
return self.client.create_run_sync(request)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Cancel one run through ``POST /runs/{run_id}/cancel`` and normalize exceptions."""
|
||||
try:
|
||||
return self.client.cancel_run_sync(run_id, request=request)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Stream run events from ``/events/sse`` with the wrapped client's reconnect policy."""
|
||||
try:
|
||||
yield from self.client.stream_events_sync(run_id, after=after)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Poll run status until terminal state and normalize client exceptions."""
|
||||
try:
|
||||
return self.client.wait_run_sync(run_id, timeout_seconds=timeout_seconds)
|
||||
except Exception as exc:
|
||||
raise _normalize_dify_agent_error(exc) from exc
|
||||
|
||||
|
||||
def _normalize_dify_agent_error(exc: Exception) -> AgentBackendError:
|
||||
"""Map public ``dify-agent`` client errors to API-side integration errors."""
|
||||
match exc:
|
||||
case DifyAgentValidationError() as error:
|
||||
return AgentBackendValidationError(
|
||||
"Agent backend request or response validation failed", detail=error.detail
|
||||
)
|
||||
case DifyAgentHTTPError() as error:
|
||||
return AgentBackendHTTPError(
|
||||
f"Agent backend HTTP {error.status_code}",
|
||||
status_code=error.status_code,
|
||||
detail=error.detail,
|
||||
)
|
||||
case DifyAgentTimeoutError() as error:
|
||||
return AgentBackendTransportError(str(error))
|
||||
case DifyAgentStreamError() as error:
|
||||
return AgentBackendStreamError(str(error))
|
||||
case DifyAgentClientError() as error:
|
||||
return AgentBackendTransportError(str(error))
|
||||
case AgentBackendError() as error:
|
||||
return error
|
||||
case _:
|
||||
return AgentBackendTransportError(str(exc) or type(exc).__name__)
|
||||
61
api/clients/agent_backend/errors.py
Normal file
61
api/clients/agent_backend/errors.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""API-side errors for the Dify Agent backend integration.
|
||||
|
||||
The wire protocol and low-level HTTP behaviour are owned by ``dify-agent``.
|
||||
This module only normalizes those client errors into the API backend's boundary
|
||||
so workflow/node code does not depend directly on transport-specific exception
|
||||
classes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AgentBackendError(Exception):
|
||||
"""Base error for API-side Agent backend integration failures."""
|
||||
|
||||
|
||||
class AgentBackendRequestBuildError(AgentBackendError):
|
||||
"""Raised when Dify product/workflow state cannot be mapped to a run request."""
|
||||
|
||||
|
||||
class AgentBackendTransportError(AgentBackendError):
|
||||
"""Raised for timeout or request-level failures talking to Agent backend."""
|
||||
|
||||
|
||||
class AgentBackendHTTPError(AgentBackendTransportError):
|
||||
"""Raised for Agent backend HTTP errors after status/detail normalization."""
|
||||
|
||||
status_code: int
|
||||
detail: object
|
||||
|
||||
def __init__(self, message: str, *, status_code: int, detail: object) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentBackendValidationError(AgentBackendError):
|
||||
"""Raised for local request validation or Agent backend 422 responses."""
|
||||
|
||||
detail: object
|
||||
|
||||
def __init__(self, message: str, *, detail: object) -> None:
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AgentBackendStreamError(AgentBackendError):
|
||||
"""Raised when an Agent backend event stream is malformed or exhausted."""
|
||||
|
||||
|
||||
class AgentBackendRunFailedError(AgentBackendError):
|
||||
"""Raised by callers that choose to translate a terminal failed run into an exception."""
|
||||
|
||||
run_id: str
|
||||
detail: Any
|
||||
|
||||
def __init__(self, run_id: str, detail: Any) -> None:
|
||||
self.run_id = run_id
|
||||
self.detail = detail
|
||||
super().__init__(f"Agent backend run failed: {run_id}")
|
||||
167
api/clients/agent_backend/event_adapter.py
Normal file
167
api/clients/agent_backend/event_adapter.py
Normal file
@ -0,0 +1,167 @@
|
||||
"""Adapt public ``dify-agent`` run events into API-internal event semantics.
|
||||
|
||||
The adapter does not define a new cross-service event contract. It consumes
|
||||
``dify_agent.protocol.RunEvent`` and produces small API-internal models that the
|
||||
future workflow Agent Node can map to Graphon/AppQueue events in phase 3.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.protocol import (
|
||||
PydanticAIStreamRunEvent,
|
||||
RunCancelledEvent,
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunPausedEvent,
|
||||
RunStartedEvent,
|
||||
RunSucceededEvent,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter
|
||||
|
||||
_EVENT_DATA_ADAPTER = TypeAdapter(object)
|
||||
|
||||
|
||||
class AgentBackendInternalEventType(StrEnum):
|
||||
"""API-only event labels used before Graphon/AppQueue integration."""
|
||||
|
||||
RUN_STARTED = "run_started"
|
||||
STREAM_EVENT = "stream_event"
|
||||
RUN_PAUSED = "run_paused"
|
||||
RUN_SUCCEEDED = "run_succeeded"
|
||||
RUN_FAILED = "run_failed"
|
||||
RUN_CANCELLED = "run_cancelled"
|
||||
|
||||
|
||||
class AgentBackendInternalEventBase(BaseModel):
|
||||
"""Common fields preserved from public Dify Agent run events."""
|
||||
|
||||
run_id: str
|
||||
source_event_id: str | None = None
|
||||
|
||||
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class AgentBackendRunStartedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal marker for a started Agent backend run."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_STARTED] = AgentBackendInternalEventType.RUN_STARTED
|
||||
|
||||
|
||||
class AgentBackendStreamInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal wrapper for one pydantic-ai stream event payload."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.STREAM_EVENT] = AgentBackendInternalEventType.STREAM_EVENT
|
||||
event_kind: str | None = None
|
||||
data: JsonValue
|
||||
|
||||
|
||||
class AgentBackendRunSucceededInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal success event carrying final output and session state."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_SUCCEEDED] = AgentBackendInternalEventType.RUN_SUCCEEDED
|
||||
output: JsonValue
|
||||
session_snapshot: CompositorSessionSnapshot
|
||||
|
||||
|
||||
class AgentBackendRunPausedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal resumable pause event for human handoff and Babysit flows."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_PAUSED] = AgentBackendInternalEventType.RUN_PAUSED
|
||||
reason: str
|
||||
message: str | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
|
||||
|
||||
class AgentBackendRunFailedInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal failure event carrying the backend-safe error text."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_FAILED] = AgentBackendInternalEventType.RUN_FAILED
|
||||
error: str
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class AgentBackendRunCancelledInternalEvent(AgentBackendInternalEventBase):
|
||||
"""API-internal terminal cancellation event."""
|
||||
|
||||
type: Literal[AgentBackendInternalEventType.RUN_CANCELLED] = AgentBackendInternalEventType.RUN_CANCELLED
|
||||
reason: str | None = None
|
||||
message: str | None = None
|
||||
|
||||
|
||||
type AgentBackendInternalEvent = Annotated[
|
||||
AgentBackendRunStartedInternalEvent
|
||||
| AgentBackendStreamInternalEvent
|
||||
| AgentBackendRunPausedInternalEvent
|
||||
| AgentBackendRunSucceededInternalEvent
|
||||
| AgentBackendRunFailedInternalEvent
|
||||
| AgentBackendRunCancelledInternalEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class AgentBackendRunEventAdapter:
|
||||
"""Maps public ``dify-agent`` event variants to API-internal event variants."""
|
||||
|
||||
def adapt(self, event: RunEvent) -> list[AgentBackendInternalEvent]:
|
||||
"""Return zero or more API-internal events derived from one public run event."""
|
||||
match event:
|
||||
case RunStartedEvent():
|
||||
return [
|
||||
AgentBackendRunStartedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
)
|
||||
]
|
||||
case PydanticAIStreamRunEvent():
|
||||
data = cast(JsonValue, _EVENT_DATA_ADAPTER.dump_python(event.data, mode="json"))
|
||||
event_kind = data.get("event_kind") if isinstance(data, dict) else None
|
||||
return [
|
||||
AgentBackendStreamInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
event_kind=event_kind if isinstance(event_kind, str) else None,
|
||||
data=data,
|
||||
)
|
||||
]
|
||||
case RunSucceededEvent():
|
||||
return [
|
||||
AgentBackendRunSucceededInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
output=event.data.output,
|
||||
session_snapshot=event.data.session_snapshot,
|
||||
)
|
||||
]
|
||||
case RunPausedEvent():
|
||||
return [
|
||||
AgentBackendRunPausedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
reason=event.data.reason,
|
||||
message=event.data.message,
|
||||
session_snapshot=event.data.session_snapshot,
|
||||
)
|
||||
]
|
||||
case RunFailedEvent():
|
||||
return [
|
||||
AgentBackendRunFailedInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
error=event.data.error,
|
||||
reason=event.data.reason,
|
||||
)
|
||||
]
|
||||
case RunCancelledEvent():
|
||||
return [
|
||||
AgentBackendRunCancelledInternalEvent(
|
||||
run_id=event.run_id,
|
||||
source_event_id=event.id,
|
||||
reason=event.data.reason,
|
||||
message=event.data.message,
|
||||
)
|
||||
]
|
||||
raise TypeError(f"unsupported agent backend run event: {type(event).__name__}")
|
||||
22
api/clients/agent_backend/factory.py
Normal file
22
api/clients/agent_backend/factory.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Factories for API-side Agent backend clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dify_agent.client import Client
|
||||
|
||||
from clients.agent_backend.client import AgentBackendRunClient, DifyAgentBackendRunClient
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
|
||||
|
||||
def create_agent_backend_run_client(
|
||||
*,
|
||||
base_url: str | None = None,
|
||||
use_fake: bool = False,
|
||||
fake_scenario: str | FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
|
||||
) -> AgentBackendRunClient:
|
||||
"""Create the API-side run client without hiding the ``dify-agent`` protocol."""
|
||||
if use_fake:
|
||||
return FakeAgentBackendRunClient(scenario=FakeAgentBackendScenario(fake_scenario))
|
||||
if base_url is None:
|
||||
raise ValueError("base_url is required when creating a real Agent backend client")
|
||||
return DifyAgentBackendRunClient(Client(base_url=base_url))
|
||||
117
api/clients/agent_backend/fake_client.py
Normal file
117
api/clients/agent_backend/fake_client.py
Normal file
@ -0,0 +1,117 @@
|
||||
"""Deterministic fake Agent backend client using public ``dify-agent`` events.
|
||||
|
||||
Tests should exercise the same ``RunEvent`` DTOs as the real HTTP client. This
|
||||
fake therefore replaces the previous custom mock protocol instead of emulating a
|
||||
separate ``agent-backend.v1`` event stream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from dify_agent.protocol import (
|
||||
CancelRunRequest,
|
||||
CancelRunResponse,
|
||||
CreateRunRequest,
|
||||
CreateRunResponse,
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
RunSucceededEventData,
|
||||
)
|
||||
|
||||
_FIXED_TIME = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
|
||||
|
||||
class FakeAgentBackendScenario(StrEnum):
|
||||
"""Deterministic fake scenarios for API-side integration tests."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
"""In-memory implementation of ``AgentBackendRunClient`` for unit tests."""
|
||||
|
||||
scenario: FakeAgentBackendScenario
|
||||
run_id: str
|
||||
request: CreateRunRequest | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scenario: FakeAgentBackendScenario = FakeAgentBackendScenario.SUCCESS,
|
||||
run_id: str = "fake-run-1",
|
||||
) -> None:
|
||||
self.scenario = scenario
|
||||
self.run_id = run_id
|
||||
self.request = None
|
||||
|
||||
def create_run(self, request: CreateRunRequest) -> CreateRunResponse:
|
||||
"""Record the request and return a deterministic accepted response."""
|
||||
self.request = request
|
||||
return CreateRunResponse(run_id=self.run_id, status="running")
|
||||
|
||||
def cancel_run(self, run_id: str, request: CancelRunRequest | None = None) -> CancelRunResponse:
|
||||
"""Return a deterministic cancellation response."""
|
||||
del request
|
||||
return CancelRunResponse(run_id=run_id, status="cancelled")
|
||||
|
||||
def stream_events(self, run_id: str, *, after: str | None = None) -> Iterator[RunEvent]:
|
||||
"""Yield the deterministic public ``RunEvent`` sequence for ``run_id``."""
|
||||
for event in self._events(run_id):
|
||||
if after is not None and event.id is not None and event.id <= after:
|
||||
continue
|
||||
yield event
|
||||
|
||||
def wait_run(self, run_id: str, *, timeout_seconds: float | None = None) -> RunStatusResponse:
|
||||
"""Return a deterministic terminal status; timeout is accepted for protocol parity."""
|
||||
del timeout_seconds
|
||||
match self.scenario:
|
||||
case FakeAgentBackendScenario.SUCCESS:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="succeeded",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
case FakeAgentBackendScenario.FAILED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="failed",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
case FakeAgentBackendScenario.SUCCESS:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunSucceededEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunSucceededEventData(
|
||||
output={"text": "hello agent"},
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.FAILED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunFailedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
192
api/clients/agent_backend/request_builder.py
Normal file
192
api/clients/agent_backend/request_builder.py
Normal file
@ -0,0 +1,192 @@
|
||||
"""Build ``dify-agent`` run requests from API-side product concepts.
|
||||
|
||||
This module is intentionally an adapter, not a wire DTO package. The emitted
|
||||
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
|
||||
protocol has a single owner. API-only context such as Agent Soul vs workflow job
|
||||
prompt is preserved in layer names and metadata until the dedicated product
|
||||
schemas land in later phases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLayerConfig,
|
||||
DifyPluginLLMLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
ExecutionContext,
|
||||
LayerExitSignals,
|
||||
RunComposition,
|
||||
RunLayerSpec,
|
||||
RunPurpose,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
model_provider: str
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentBackendOutputConfig(BaseModel):
|
||||
"""API-side structured output declaration for the conventional output layer."""
|
||||
|
||||
json_schema: dict[str, JsonValue]
|
||||
name: str = "final_result"
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: ExecutionContext
|
||||
workflow_node_job_prompt: str
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "workflow_node"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
suspend_on_exit: bool = False
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("workflow_node_job_prompt", "user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "workflow_node_job"},
|
||||
config=PromptLayerConfig(prefix=run_input.workflow_node_job_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "workflow_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLayerConfig(
|
||||
tenant_id=run_input.model.tenant_id,
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
user_id=run_input.model.user_id,
|
||||
),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
name=run_input.output.name,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
execution_context=run_input.execution_context,
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_SENSITIVE_KEY_PARTS = ("secret", "credential", "token", "password", "api_key")
|
||||
|
||||
|
||||
def redact_for_agent_backend_log(value: object) -> object:
|
||||
"""Return a JSON-like copy with credential-bearing keys redacted for logs/tests."""
|
||||
if isinstance(value, BaseModel):
|
||||
return redact_for_agent_backend_log(value.model_dump(mode="json", warnings=False))
|
||||
if isinstance(value, dict):
|
||||
redacted: dict[object, object] = {}
|
||||
for key, item in value.items():
|
||||
key_text = str(key).lower()
|
||||
if any(part in key_text for part in _SENSITIVE_KEY_PARTS):
|
||||
redacted[key] = "[REDACTED]"
|
||||
else:
|
||||
redacted[key] = redact_for_agent_backend_log(item)
|
||||
return redacted
|
||||
if isinstance(value, list):
|
||||
return [redact_for_agent_backend_log(item) for item in value]
|
||||
return value
|
||||
@ -4,7 +4,6 @@ CLI command modules extracted from `commands.py`.
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .plugin import (
|
||||
backfill_plugin_auto_upgrade,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
install_plugins,
|
||||
@ -15,7 +14,6 @@ from .plugin import (
|
||||
setup_system_trigger_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
)
|
||||
from .rbac import migrate_member_roles_to_rbac
|
||||
from .retention import (
|
||||
archive_workflow_runs,
|
||||
clean_expired_messages,
|
||||
@ -39,7 +37,6 @@ from .vector import (
|
||||
__all__ = [
|
||||
"add_qdrant_index",
|
||||
"archive_workflow_runs",
|
||||
"backfill_plugin_auto_upgrade",
|
||||
"clean_expired_messages",
|
||||
"clean_workflow_runs",
|
||||
"cleanup_orphaned_draft_variables",
|
||||
@ -58,7 +55,6 @@ __all__ = [
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_member_roles_to_rbac",
|
||||
"migrate_oss",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import click
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
|
||||
from configs import dify_config
|
||||
@ -15,13 +14,11 @@ from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
@ -188,9 +185,9 @@ def transform_datasource_credentials(environment: str):
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
@ -405,110 +402,6 @@ def migrate_data_for_plugin():
|
||||
click.echo(click.style("Migrate data for plugin completed.", fg="green"))
|
||||
|
||||
|
||||
def _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit: int | None = None):
|
||||
category_count = len(TenantPluginAutoUpgradeStrategy.PluginCategory)
|
||||
stmt = (
|
||||
select(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
.group_by(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
.having(func.count(func.distinct(TenantPluginAutoUpgradeStrategy.category)) < category_count)
|
||||
.order_by(TenantPluginAutoUpgradeStrategy.tenant_id)
|
||||
)
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def _count_auto_upgrade_strategy_tenant_ids(limit: int | None) -> int:
|
||||
candidate_stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).subquery()
|
||||
return db.session.scalar(select(func.count()).select_from(candidate_stmt)) or 0
|
||||
|
||||
|
||||
def _iter_auto_upgrade_strategy_tenant_ids(limit: int | None):
|
||||
stmt = _candidate_auto_upgrade_strategy_tenant_ids_stmt(limit).execution_options(yield_per=1000)
|
||||
yield from db.session.scalars(stmt)
|
||||
|
||||
|
||||
@click.command(
|
||||
"backfill-plugin-auto-upgrade",
|
||||
help="Backfill category-scoped plugin auto-upgrade strategies and normalize plugin lists.",
|
||||
)
|
||||
@click.option("--tenant-id", multiple=True, help="Tenant ID to backfill. Can be passed multiple times.")
|
||||
@click.option("--limit", type=int, default=None, help="Maximum number of candidate tenants to process.")
|
||||
@click.option("--batch-size", type=int, default=500, show_default=True, help="Progress reporting batch size.")
|
||||
@click.option("--dry-run", is_flag=True, help="Only print candidate tenant count.")
|
||||
def backfill_plugin_auto_upgrade(
|
||||
tenant_id: tuple[str, ...],
|
||||
limit: int | None,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Backfill historical auto-upgrade strategies after the category column exists.
|
||||
|
||||
Missing category rows are created from the tenant's tool/default row. Pure default
|
||||
strategies become latest for model plugins and fix-only for all other categories.
|
||||
Tenants with include/exclude plugin IDs are split
|
||||
by installed plugin category using plugin daemon metadata.
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
candidate_count = len(tenant_id) if tenant_id else _count_auto_upgrade_strategy_tenant_ids(limit)
|
||||
click.echo(click.style(f"Found {candidate_count} candidate tenants.", fg="yellow"))
|
||||
|
||||
if dry_run:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(click.style(f"Dry run completed. elapsed={elapsed:.2f}s", fg="green"))
|
||||
return
|
||||
|
||||
tenant_ids = list(tenant_id) if tenant_id else _iter_auto_upgrade_strategy_tenant_ids(limit)
|
||||
|
||||
backfilled_count = 0
|
||||
created_count = 0
|
||||
normalized_count = 0
|
||||
skipped_count = 0
|
||||
failed_count = 0
|
||||
for index, current_tenant_id in enumerate(tenant_ids, start=1):
|
||||
try:
|
||||
result = PluginAutoUpgradeService.backfill_strategy_categories(
|
||||
current_tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
click.echo(click.style(f"Failed tenant {current_tenant_id}: {str(e)}", fg="red"))
|
||||
continue
|
||||
|
||||
if result.created_count > 0:
|
||||
backfilled_count += 1
|
||||
created_count += result.created_count
|
||||
elif not result.normalized:
|
||||
skipped_count += 1
|
||||
if result.normalized:
|
||||
normalized_count += 1
|
||||
|
||||
if batch_size > 0 and index % batch_size == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Processed {index}/{candidate_count} tenants. "
|
||||
f"backfilled={backfilled_count}, created_rows={created_count}, "
|
||||
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
|
||||
f"elapsed={time.perf_counter() - start_at:.2f}s",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Backfill plugin auto-upgrade strategy categories completed. "
|
||||
f"backfilled={backfilled_count}, created_rows={created_count}, "
|
||||
f"normalized={normalized_count}, skipped={skipped_count}, failed={failed_count}, "
|
||||
f"elapsed={elapsed:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command("extract-plugins", help="Extract plugins.")
|
||||
@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl")
|
||||
@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10)
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import click
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import TenantAccountJoin, TenantAccountRole
|
||||
from services.enterprise.rbac_service import ListOption, RBACService
|
||||
|
||||
|
||||
def _resolve_builtin_role_id(tenant_id: str, operator_account_id: str, legacy_role: str) -> str:
|
||||
"""Resolve a legacy workspace role to the current tenant's builtin RBAC role id.
|
||||
|
||||
The migration replays the old `TenantAccountJoin.role` values onto the
|
||||
RBAC member-role binding API. Builtin RBAC roles are tenant-scoped and
|
||||
identified by runtime ids, so the command must look them up per tenant.
|
||||
"""
|
||||
expected_builtin_name = {
|
||||
TenantAccountRole.OWNER.value: "所有者",
|
||||
TenantAccountRole.ADMIN.value: "管理者",
|
||||
TenantAccountRole.EDITOR.value: "编辑者",
|
||||
TenantAccountRole.NORMAL.value: "普通用户",
|
||||
TenantAccountRole.DATASET_OPERATOR.value: "知识库操作员",
|
||||
}.get(legacy_role)
|
||||
if not expected_builtin_name:
|
||||
raise ValueError(f"Unsupported legacy workspace role: {legacy_role}")
|
||||
|
||||
roles = RBACService.Roles.list(
|
||||
tenant_id=tenant_id,
|
||||
account_id=operator_account_id,
|
||||
options=ListOption(page_number=1, results_per_page=100),
|
||||
).data
|
||||
for role in roles:
|
||||
if role.is_builtin and role.category == "global_system_default" and role.name == expected_builtin_name:
|
||||
return str(role.id)
|
||||
|
||||
raise ValueError(f"Builtin RBAC role not found for tenant={tenant_id}, legacy_role={legacy_role}")
|
||||
|
||||
|
||||
@click.command("rbac-migrate-member-roles", help="Migrate legacy workspace member roles into RBAC member-role bindings.")
|
||||
@click.option("--tenant-id", help="Only migrate a single workspace.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Preview the migration without writing RBAC bindings.")
|
||||
def migrate_member_roles_to_rbac(tenant_id: str | None, dry_run: bool) -> None:
|
||||
"""Backfill RBAC member-role bindings from legacy `TenantAccountJoin.role` data.
|
||||
|
||||
This is an offline migration command for workspaces that already have
|
||||
members in the legacy role model but need matching records in the RBAC
|
||||
member-role binding store.
|
||||
"""
|
||||
click.echo(click.style("Starting RBAC member-role migration.", fg="green"))
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantAccountJoin).order_by(TenantAccountJoin.tenant_id.asc(), TenantAccountJoin.id.asc())
|
||||
if tenant_id:
|
||||
stmt = stmt.where(TenantAccountJoin.tenant_id == tenant_id)
|
||||
|
||||
joins = list(session.scalars(stmt).all())
|
||||
|
||||
if not joins:
|
||||
click.echo(click.style("No workspace members found for migration.", fg="yellow"))
|
||||
return
|
||||
|
||||
owner_account_by_tenant: dict[str, str] = {}
|
||||
resolved_role_ids: dict[tuple[str, str], str] = {}
|
||||
migrated_count = 0
|
||||
|
||||
for join in joins:
|
||||
workspace_id = str(join.tenant_id)
|
||||
member_account_id = str(join.account_id)
|
||||
legacy_role = str(join.role)
|
||||
|
||||
if workspace_id not in owner_account_by_tenant:
|
||||
owner_join = next(
|
||||
(
|
||||
item
|
||||
for item in joins
|
||||
if str(item.tenant_id) == workspace_id and str(item.role) == TenantAccountRole.OWNER.value
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not owner_join:
|
||||
raise ValueError(f"Workspace owner not found for tenant={workspace_id}")
|
||||
owner_account_by_tenant[workspace_id] = str(owner_join.account_id)
|
||||
|
||||
operator_account_id = owner_account_by_tenant[workspace_id]
|
||||
cache_key = (workspace_id, legacy_role)
|
||||
if cache_key not in resolved_role_ids:
|
||||
resolved_role_ids[cache_key] = _resolve_builtin_role_id(workspace_id, operator_account_id, legacy_role)
|
||||
|
||||
resolved_role_id = resolved_role_ids[cache_key]
|
||||
click.echo(
|
||||
f"tenant={workspace_id} member={member_account_id} legacy_role={legacy_role} -> rbac_role_id={resolved_role_id}"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
continue
|
||||
|
||||
RBACService.MemberRoles.replace(
|
||||
tenant_id=workspace_id,
|
||||
account_id=operator_account_id,
|
||||
member_account_id=member_account_id,
|
||||
role_ids=[resolved_role_id],
|
||||
)
|
||||
migrated_count += 1
|
||||
|
||||
if dry_run:
|
||||
click.echo(click.style("Dry run completed. No RBAC bindings were written.", fg="yellow"))
|
||||
else:
|
||||
click.echo(click.style(f"RBAC member-role migration completed. Migrated {migrated_count} members.", fg="green"))
|
||||
@ -14,6 +14,7 @@ from libs.rsa import generate_key_pair
|
||||
from models import Tenant
|
||||
from models.model import App, AppMode, Conversation
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,13 +24,16 @@ DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
@click.command(
|
||||
"reset-encrypt-key-pair",
|
||||
help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. "
|
||||
"After the reset, all LLM credentials will become invalid, "
|
||||
"requiring re-entry."
|
||||
"After the reset, all LLM credentials and tool provider credentials "
|
||||
"(builtin / API / MCP) will be purged, requiring re-entry. "
|
||||
"Only support SELF_HOSTED mode.",
|
||||
)
|
||||
@click.confirmation_option(
|
||||
prompt=click.style(
|
||||
"Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red"
|
||||
"Are you sure you want to reset encrypt key pair? "
|
||||
"This will also purge builtin / API / MCP tool provider records for every tenant. "
|
||||
"This operation cannot be rolled back!",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
def reset_encrypt_key_pair():
|
||||
@ -53,6 +57,13 @@ def reset_encrypt_key_pair():
|
||||
session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id))
|
||||
session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id))
|
||||
|
||||
# Purge tool provider records that hold credentials encrypted under the
|
||||
# tenant key. Leaving them in place causes /console/api/workspaces/current/
|
||||
# tool-providers to 500 because decryption fails on stale ciphertext (#35396).
|
||||
session.execute(delete(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant.id))
|
||||
session.execute(delete(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant.id))
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
|
||||
|
||||
@ -23,9 +23,10 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
RBAC_ENABLED: bool = Field(
|
||||
description="Enable enterprise RBAC APIs. When disabled, compatibility responses fall back to legacy roles.",
|
||||
ENTERPRISE_DISABLE_RUNTIME_CREDENTIAL_CHECK: bool = Field(
|
||||
default=False,
|
||||
description="If disabled, credential policy check is only performed when saving workflows."
|
||||
"This helps gain runtime performance by trading off consistency.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
@ -1406,32 +1406,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class EvaluationConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for evaluation runtime
|
||||
"""
|
||||
|
||||
EVALUATION_FRAMEWORK: str = Field(
|
||||
description="Evaluation framework to use (ragas/deepeval/none)",
|
||||
default="none",
|
||||
)
|
||||
|
||||
EVALUATION_MAX_CONCURRENT_RUNS: PositiveInt = Field(
|
||||
description="Maximum number of concurrent evaluation runs per tenant",
|
||||
default=3,
|
||||
)
|
||||
|
||||
EVALUATION_MAX_DATASET_ROWS: PositiveInt = Field(
|
||||
description="Maximum number of rows allowed in an evaluation dataset",
|
||||
default=500,
|
||||
)
|
||||
|
||||
EVALUATION_TASK_TIMEOUT: PositiveInt = Field(
|
||||
description="Timeout in seconds for a single evaluation task",
|
||||
default=3600,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1445,7 +1419,6 @@ class FeatureConfig(
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
EvaluationConfig,
|
||||
FileAccessConfig,
|
||||
FileUploadConfig,
|
||||
HttpConfig,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from urllib.parse import parse_qsl, quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
@ -50,28 +50,30 @@ from .vdb.vastbase_vector_config import VastbaseVectorConfig
|
||||
from .vdb.vikingdb_config import VikingDBConfig
|
||||
from .vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
_VALID_STORAGE_TYPE = Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
]
|
||||
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: Literal[
|
||||
"opendal",
|
||||
"s3",
|
||||
"aliyun-oss",
|
||||
"azure-blob",
|
||||
"baidu-obs",
|
||||
"clickzetta-volume",
|
||||
"google-storage",
|
||||
"huawei-obs",
|
||||
"oci-storage",
|
||||
"tencent-cos",
|
||||
"volcengine-tos",
|
||||
"supabase",
|
||||
"local",
|
||||
] = Field(
|
||||
STORAGE_TYPE: _VALID_STORAGE_TYPE = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', "
|
||||
"'clickzetta-volume', 'google-storage', 'huawei-obs', 'oci-storage', 'tencent-cos', "
|
||||
"'volcengine-tos', 'supabase'. Default is 'opendal'.",
|
||||
default="opendal",
|
||||
default=cast(_VALID_STORAGE_TYPE, "opendal"),
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
|
||||
91
api/conftest.py
Normal file
91
api/conftest.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""Global pytest hooks for Dify backend tests.
|
||||
|
||||
This root conftest is loaded before package-specific conftests, which lets tests opt
|
||||
into Docker-backed middleware before application modules read environment config.
|
||||
It intentionally lives at the API root because pytest applies conftest.py files to
|
||||
tests below their directory, and this setup is shared by api/tests and api/providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.pytest_dify import (
|
||||
DEFAULT_MIDDLEWARE_SERVICES,
|
||||
DEFAULT_VDB_SERVICES,
|
||||
DockerComposeStack,
|
||||
build_middleware_stack,
|
||||
build_vdb_stack,
|
||||
ensure_backend_test_environment,
|
||||
ensure_compose_env_files,
|
||||
parse_services,
|
||||
)
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
_DIFY_COMPOSE_STACKS_KEY = pytest.StashKey[list[DockerComposeStack]]()
|
||||
|
||||
# This must run at import time because package-specific conftests can import the
|
||||
# Flask app before pytest_configure hooks from this file are called.
|
||||
ensure_backend_test_environment(_REPO_ROOT)
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
group = parser.getgroup("dify")
|
||||
group.addoption(
|
||||
"--start-middleware",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start the Docker middleware services needed by API integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--middleware-services",
|
||||
default=",".join(DEFAULT_MIDDLEWARE_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.middleware.yaml to start.",
|
||||
)
|
||||
group.addoption(
|
||||
"--start-vdb",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Start vector-store Docker services for VDB integration tests.",
|
||||
)
|
||||
group.addoption(
|
||||
"--vdb-services",
|
||||
default=",".join(DEFAULT_VDB_SERVICES),
|
||||
help="Comma-separated services from docker/docker-compose.yaml to start for VDB tests.",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = []
|
||||
|
||||
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
config = session.config
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks: list[DockerComposeStack] = []
|
||||
if config.getoption("start_middleware"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_middleware_stack(_REPO_ROOT, parse_services(config.getoption("middleware_services")))
|
||||
stack.up()
|
||||
stacks.append(stack)
|
||||
|
||||
if config.getoption("start_vdb"):
|
||||
ensure_compose_env_files(_REPO_ROOT)
|
||||
stack = build_vdb_stack(_REPO_ROOT, parse_services(config.getoption("vdb_services")))
|
||||
stack.up()
|
||||
stacks.append(stack)
|
||||
|
||||
config.stash[_DIFY_COMPOSE_STACKS_KEY] = stacks
|
||||
|
||||
|
||||
def pytest_unconfigure(config: pytest.Config) -> None:
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
stacks = config.stash.get(_DIFY_COMPOSE_STACKS_KEY, [])
|
||||
for stack in reversed(stacks):
|
||||
stack.down()
|
||||
@ -34,6 +34,7 @@ from controllers.common.schema import (
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
```
|
||||
|
||||
Register request payload and query models with `register_schema_models(...)`:
|
||||
@ -82,7 +83,7 @@ register_schema_models(console_ns, DraftWorkflowNodeRunPayload)
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
result = service.run(..., inputs=payload.inputs, query=payload.query)
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(WorkflowRunNodeExecutionResponse, result)
|
||||
```
|
||||
|
||||
## Query Parameters
|
||||
@ -105,7 +106,7 @@ class WorkflowRunListQuery(BaseModel):
|
||||
def get(self, app_model: App):
|
||||
query = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
result = service.list(..., limit=query.limit, last_id=query.last_id)
|
||||
return WorkflowRunPaginationResponse.model_validate(result, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(WorkflowRunPaginationResponse, result)
|
||||
```
|
||||
|
||||
Do not do this for GET query parameters:
|
||||
@ -145,10 +146,25 @@ def post(...):
|
||||
Serialize explicitly:
|
||||
|
||||
```python
|
||||
return WorkflowRunNodeExecutionResponse.model_validate(
|
||||
workflow_node_execution,
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
return dump_response(WorkflowRunNodeExecutionResponse, workflow_node_execution)
|
||||
```
|
||||
|
||||
`dump_response(...)` is the preferred response serialization helper for a single Pydantic response DTO. It validates
|
||||
with `from_attributes=True` and returns `model_dump(mode="json")`, so SQLAlchemy models, plain objects, dictionaries,
|
||||
Pydantic aliases, computed fields, and `datetime` values are serialized consistently.
|
||||
|
||||
For wrapper responses, pass a dictionary with the public wrapper fields:
|
||||
|
||||
```python
|
||||
return dump_response(
|
||||
WorkflowRunPaginationResponse,
|
||||
{
|
||||
"data": workflow_runs,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
If the service can return `None`, translate that into the expected HTTP error before validation:
|
||||
@ -158,9 +174,12 @@ workflow_run = service.get_workflow_run(...)
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
return WorkflowRunDetailResponse.model_validate(workflow_run, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(WorkflowRunDetailResponse, workflow_run)
|
||||
```
|
||||
|
||||
Use manual `model_validate(...).model_dump(...)` only when the endpoint needs behavior that `dump_response(...)` does
|
||||
not provide, such as returning a non-dict payload, intentionally excluding fields, or composing a `(body, status)` tuple.
|
||||
|
||||
## Legacy Flask-RESTX Patterns
|
||||
|
||||
Avoid adding these patterns to new or migrated endpoints:
|
||||
@ -190,4 +209,3 @@ Inspect affected endpoints with `jq`. Check that:
|
||||
- Request bodies appear only where the endpoint has a body.
|
||||
- Responses reference the expected `*Response` schema.
|
||||
- Response schemas use public serialized names, not internal validation aliases like `inputs_dict`.
|
||||
|
||||
|
||||
@ -2,8 +2,9 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import IconType
|
||||
|
||||
@ -19,6 +20,113 @@ class SystemParameters(BaseModel):
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class SimpleResultResponse(ResponseModel):
|
||||
result: str
|
||||
|
||||
|
||||
class SimpleResultMessageResponse(ResponseModel):
|
||||
result: str
|
||||
message: str
|
||||
|
||||
|
||||
class SimpleMessageResponse(ResponseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class SimpleDataResponse(ResponseModel):
|
||||
data: str
|
||||
|
||||
|
||||
class SimpleResultDataResponse(ResponseModel):
|
||||
result: str
|
||||
data: str
|
||||
|
||||
|
||||
class SimpleResultStringListResponse(ResponseModel):
|
||||
result: str
|
||||
data: list[str]
|
||||
|
||||
|
||||
class SimpleResultOptionalDataResponse(ResponseModel):
|
||||
result: str
|
||||
data: str | None = None
|
||||
|
||||
|
||||
class AccessTokenData(ResponseModel):
|
||||
access_token: str
|
||||
|
||||
|
||||
class AccessTokenResultResponse(ResponseModel):
|
||||
result: str
|
||||
data: AccessTokenData
|
||||
|
||||
|
||||
class VerificationTokenResponse(ResponseModel):
|
||||
is_valid: bool
|
||||
email: str
|
||||
token: str
|
||||
|
||||
|
||||
class LoginStatusResponse(ResponseModel):
|
||||
logged_in: bool
|
||||
app_logged_in: bool
|
||||
|
||||
|
||||
class AccessModeResponse(ResponseModel):
|
||||
access_mode: str = Field(serialization_alias="accessMode", validation_alias="accessMode")
|
||||
|
||||
|
||||
class BooleanResultResponse(ResponseModel):
|
||||
result: bool
|
||||
|
||||
|
||||
class SuccessResponse(ResponseModel):
|
||||
success: bool
|
||||
|
||||
|
||||
class UsageCheckResponse(ResponseModel):
|
||||
is_using: bool
|
||||
|
||||
|
||||
class UsageCountResponse(ResponseModel):
|
||||
is_using: bool
|
||||
count: int
|
||||
|
||||
|
||||
class IndexInfoResponse(ResponseModel):
|
||||
welcome: str
|
||||
api_version: str
|
||||
server_version: str
|
||||
|
||||
|
||||
class AvatarUrlResponse(ResponseModel):
|
||||
avatar_url: str
|
||||
|
||||
|
||||
class TextContentResponse(ResponseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class AllowedExtensionsResponse(ResponseModel):
|
||||
allowed_extensions: list[str]
|
||||
|
||||
|
||||
class UrlResponse(ResponseModel):
|
||||
url: str
|
||||
|
||||
|
||||
class RedirectUrlResponse(ResponseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
class ApiBaseUrlResponse(ResponseModel):
|
||||
api_base_url: str
|
||||
|
||||
|
||||
class NewAppResponse(ResponseModel):
|
||||
new_app_id: str
|
||||
|
||||
|
||||
class Parameters(BaseModel):
|
||||
opening_statement: str | None = None
|
||||
suggested_questions: list[str]
|
||||
|
||||
@ -1,36 +1,21 @@
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
import json
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
|
||||
|
||||
def stringify_form_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
"""Serialize default values into strings expected by human-input form clients."""
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
@ -39,6 +39,7 @@ QueryParamDoc = TypedDict(
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
|
||||
schema = _swagger_2_compatible_schema(schema)
|
||||
nested_definitions = schema.get("$defs")
|
||||
schema_to_register = dict(schema)
|
||||
if isinstance(nested_definitions, dict):
|
||||
@ -65,6 +66,35 @@ def _register_schema_model(namespace: Namespace, model: type[BaseModel], *, mode
|
||||
)
|
||||
|
||||
|
||||
def _swagger_2_compatible_schema(value: Any) -> Any:
|
||||
if isinstance(value, list):
|
||||
return [_swagger_2_compatible_schema(item) for item in value]
|
||||
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
converted = {key: _swagger_2_compatible_schema(child) for key, child in value.items()}
|
||||
any_of = value.get("anyOf")
|
||||
if not isinstance(any_of, list):
|
||||
return converted
|
||||
|
||||
non_null_candidates = [
|
||||
candidate for candidate in any_of if isinstance(candidate, Mapping) and candidate.get("type") != "null"
|
||||
]
|
||||
has_null_candidate = any(isinstance(candidate, Mapping) and candidate.get("type") == "null" for candidate in any_of)
|
||||
if not has_null_candidate or len(non_null_candidates) != 1:
|
||||
return converted
|
||||
|
||||
non_null_schema = _swagger_2_compatible_schema(dict(non_null_candidates[0]))
|
||||
if not isinstance(non_null_schema, dict):
|
||||
return converted
|
||||
|
||||
converted.pop("anyOf", None)
|
||||
converted.update(non_null_schema)
|
||||
converted["x-nullable"] = True
|
||||
return converted
|
||||
|
||||
|
||||
def register_schema_model(namespace: Namespace, model: type[BaseModel]) -> None:
|
||||
"""Register a BaseModel and its nested schema definitions for Swagger documentation."""
|
||||
|
||||
|
||||
@ -33,7 +33,6 @@ for module_name in RESOURCE_MODULES:
|
||||
# Ensure resource modules are imported so route decorators are evaluated.
|
||||
# Import other controllers
|
||||
from . import (
|
||||
admin,
|
||||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
@ -45,6 +44,8 @@ from . import (
|
||||
spec,
|
||||
version,
|
||||
)
|
||||
from .agent import composer as agent_composer
|
||||
from .agent import roster as agent_roster
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
@ -108,9 +109,6 @@ from .datasets.rag_pipeline import (
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import evaluation controllers
|
||||
from .evaluation import evaluation
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
banner,
|
||||
@ -120,12 +118,8 @@ from .explore import (
|
||||
saved_message,
|
||||
trial,
|
||||
)
|
||||
from .socketio import workflow as socketio_workflow
|
||||
|
||||
# Import snippet controllers
|
||||
from .snippets import snippet_workflow, snippet_workflow_draft_variable
|
||||
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# Import snippet controllers
|
||||
# Import tag controllers
|
||||
from .tag import tags
|
||||
|
||||
@ -139,8 +133,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
rbac,
|
||||
snippets,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
@ -151,10 +143,11 @@ api.add_namespace(console_ns)
|
||||
__all__ = [
|
||||
"account",
|
||||
"activate",
|
||||
"admin",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
"annotation",
|
||||
"api",
|
||||
"apikey",
|
||||
@ -178,7 +171,6 @@ __all__ = [
|
||||
"datasource_content_preview",
|
||||
"email_register",
|
||||
"endpoint",
|
||||
"evaluation",
|
||||
"extension",
|
||||
"external",
|
||||
"feature",
|
||||
@ -209,17 +201,10 @@ __all__ = [
|
||||
"rag_pipeline_draft_variable",
|
||||
"rag_pipeline_import",
|
||||
"rag_pipeline_workflow",
|
||||
"rbac",
|
||||
"recommended_app",
|
||||
"saved_message",
|
||||
"setup",
|
||||
"site",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippet_workflow_draft_variable",
|
||||
"snippets",
|
||||
"snippets",
|
||||
"socketio_workflow",
|
||||
"spec",
|
||||
"statistic",
|
||||
|
||||
@ -1,64 +1,11 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
|
||||
|
||||
class InsertExploreAppPayload(BaseModel):
|
||||
app_id: str = Field(...)
|
||||
desc: str | None = None
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
language: str = Field(...)
|
||||
category: str = Field(...)
|
||||
position: int = Field(...)
|
||||
can_trial: bool = Field(default=False)
|
||||
trial_limit: int = Field(default=0)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
|
||||
class InsertExploreBannerPayload(BaseModel):
|
||||
category: str = Field(...)
|
||||
title: str = Field(...)
|
||||
description: str = Field(...)
|
||||
img_src: str = Field(..., alias="img-src")
|
||||
language: str = Field(default="en-US")
|
||||
link: str = Field(...)
|
||||
sort: int = Field(...)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def validate_language(cls, value: str) -> str:
|
||||
return supported_language(value)
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
|
||||
register_schema_models(console_ns, InsertExploreAppPayload, InsertExploreBannerPayload)
|
||||
|
||||
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@ -76,353 +23,3 @@ def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(console_ns.models[InsertExploreAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == payload.app_id)).scalar_one_or_none()
|
||||
if not app:
|
||||
raise NotFound(f"App '{payload.app_id}' is not found")
|
||||
|
||||
site = app.site
|
||||
if not site:
|
||||
desc = payload.desc or ""
|
||||
copy_right = payload.copyright or ""
|
||||
privacy_policy = payload.privacy_policy or ""
|
||||
custom_disclaimer = payload.custom_disclaimer or ""
|
||||
else:
|
||||
desc = site.description or payload.desc or ""
|
||||
copy_right = site.copyright or payload.copyright or ""
|
||||
privacy_policy = site.privacy_policy or payload.privacy_policy or ""
|
||||
custom_disclaimer = site.custom_disclaimer or payload.custom_disclaimer or ""
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
recommended_app = RecommendedApp(
|
||||
app_id=app.id,
|
||||
description=desc,
|
||||
copyright=copy_right,
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
language=payload.language,
|
||||
category=payload.category,
|
||||
position=payload.position,
|
||||
)
|
||||
|
||||
db.session.add(recommended_app)
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
|
||||
app.is_public = True
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
else:
|
||||
recommended_app.description = desc
|
||||
recommended_app.copyright = copy_right
|
||||
recommended_app.privacy_policy = privacy_policy
|
||||
recommended_app.custom_disclaimer = custom_disclaimer
|
||||
recommended_app.language = payload.language
|
||||
recommended_app.category = payload.category
|
||||
recommended_app.position = payload.position
|
||||
|
||||
if payload.can_trial:
|
||||
trial_app = db.session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == payload.app_id)
|
||||
).scalar_one_or_none()
|
||||
if not trial_app:
|
||||
db.session.add(
|
||||
TrialApp(
|
||||
app_id=payload.app_id,
|
||||
tenant_id=app.tenant_id,
|
||||
trial_limit=payload.trial_limit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
trial_app.trial_limit = payload.trial_limit
|
||||
app.is_public = True
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id: UUID):
|
||||
with session_factory.create_session() as session:
|
||||
recommended_app = session.execute(
|
||||
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not recommended_app:
|
||||
return {"result": "success"}, 204
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
|
||||
|
||||
if app:
|
||||
app.is_public = False
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
|
||||
trial_app = session.execute(
|
||||
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
|
||||
).scalar_one_or_none()
|
||||
if trial_app:
|
||||
session.delete(trial_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/admin/insert-explore-banner")
|
||||
class InsertExploreBannerApi(Resource):
|
||||
@console_ns.doc("insert_explore_banner")
|
||||
@console_ns.doc(description="Insert an explore banner")
|
||||
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
|
||||
@console_ns.response(201, "Banner inserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
|
||||
|
||||
banner = ExporleBanner(
|
||||
content={
|
||||
"category": payload.category,
|
||||
"title": payload.title,
|
||||
"description": payload.description,
|
||||
"img-src": payload.img_src,
|
||||
},
|
||||
link=payload.link,
|
||||
sort=payload.sort,
|
||||
language=payload.language,
|
||||
)
|
||||
db.session.add(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
|
||||
class DeleteExploreBannerApi(Resource):
|
||||
@console_ns.doc("delete_explore_banner")
|
||||
@console_ns.doc(description="Delete an explore banner")
|
||||
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
|
||||
@console_ns.response(204, "Banner deleted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, banner_id):
|
||||
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
|
||||
if not banner:
|
||||
raise NotFound(f"Banner '{banner_id}' is not found")
|
||||
|
||||
db.session.delete(banner)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
register_schema_models(console_ns, UpsertNotificationPayload, BatchAddNotificationAccountsPayload)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.stream.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.stream.seek(0)
|
||||
content = file.stream.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
3
api/controllers/console/agent/__init__.py
Normal file
3
api/controllers/console/agent/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from . import composer, roster
|
||||
|
||||
__all__ = ["composer", "roster"]
|
||||
153
api/controllers/console/agent/composer.py
Normal file
153
api/controllers/console/agent/composer.py
Normal file
@ -0,0 +1,153 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
|
||||
class WorkflowAgentComposerApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
|
||||
class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
|
||||
class WorkflowAgentComposerImpactApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
|
||||
class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
|
||||
class AgentAppComposerApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
|
||||
class AgentAppComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def post(self, app_model):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
|
||||
class AgentAppComposerCandidatesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
130
api/controllers/console/agent/roster.py
Normal file
130
api/controllers/console/agent/roster.py
Normal file
@ -0,0 +1,130 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
|
||||
class AgentInviteOptionsQuery(RosterListQuery):
|
||||
app_id: str | None = Field(default=None, description="Workflow app id for in-current-workflow markers")
|
||||
|
||||
|
||||
class AgentIdPath(BaseModel):
|
||||
agent_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AgentInviteOptionsQuery,
|
||||
AgentIdPath,
|
||||
RosterAgentCreatePayload,
|
||||
RosterAgentUpdatePayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
|
||||
|
||||
def _agent_roster_service() -> AgentRosterService:
|
||||
return AgentRosterService(db.session)
|
||||
|
||||
|
||||
@console_ns.route("/agents")
|
||||
class AgentRosterListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
|
||||
|
||||
|
||||
@console_ns.route("/agents/invite-options")
|
||||
class AgentInviteOptionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>")
|
||||
class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions")
|
||||
class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
|
||||
class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id, version_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
)
|
||||
@ -11,6 +11,7 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
@ -21,12 +22,6 @@ from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
id: str
|
||||
type: str
|
||||
@ -37,7 +32,7 @@ class ApiKeyItem(ResponseModel):
|
||||
@field_validator("last_used_at", "created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ApiKeyList(ResponseModel):
|
||||
|
||||
@ -3,7 +3,6 @@ import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -13,8 +12,9 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
@ -31,19 +31,16 @@ from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from models.workflow import resolve_workflow_kind
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
@ -181,12 +178,6 @@ class AppTracePayload(BaseModel):
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class Tag(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -203,7 +194,7 @@ class WorkflowPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfigPartial(ResponseModel):
|
||||
@ -217,7 +208,7 @@ class ModelConfigPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
@ -278,7 +269,7 @@ class ModelConfig(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class Site(ResponseModel):
|
||||
@ -321,7 +312,7 @@ class Site(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DeletedTool(ResponseModel):
|
||||
@ -355,9 +346,6 @@ class AppPartial(ResponseModel):
|
||||
create_user_name: str | None = None
|
||||
author_name: str | None = None
|
||||
has_draft_trigger: bool | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
permission_keys: list[str] = Field(default_factory=list)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -367,7 +355,7 @@ class AppPartial(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetail(ResponseModel):
|
||||
@ -392,15 +380,12 @@ class AppDetail(ResponseModel):
|
||||
updated_by: str | None = None
|
||||
updated_at: int | None = None
|
||||
access_mode: str | None = None
|
||||
workflow_type: str | None = None
|
||||
workflow_kind: str | None = None
|
||||
tags: list[Tag] = Field(default_factory=list)
|
||||
permission_keys: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AppDetailWithSite(AppDetail):
|
||||
@ -428,23 +413,8 @@ class AppExportResponse(ResponseModel):
|
||||
data: str
|
||||
|
||||
|
||||
def _collect_app_access_permission_keys(access_matrix: enterprise_rbac_service.AppAccessMatrix) -> list[str]:
|
||||
permission_keys: list[str] = []
|
||||
seen_permission_keys: set[str] = set()
|
||||
|
||||
for item in access_matrix.items:
|
||||
if not item.policy:
|
||||
continue
|
||||
for permission_key in item.policy.permission_keys:
|
||||
if permission_key in seen_permission_keys:
|
||||
continue
|
||||
seen_permission_keys.add(permission_key)
|
||||
permission_keys.append(permission_key)
|
||||
|
||||
return permission_keys
|
||||
|
||||
|
||||
register_enum_models(console_ns, RetrievalMethod, WorkflowExecutionStatus, DatasetPermissionEnum)
|
||||
register_response_schema_models(console_ns, RedirectUrlResponse, SimpleResultResponse)
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
@ -529,20 +499,6 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
if app_pagination.items:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
app_ids = [str(app.id) for app in app_pagination.items]
|
||||
permission_keys_map = enterprise_rbac_service.RBACService.AppPermissions.batch_get(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
app_ids,
|
||||
)
|
||||
for app in app_pagination.items:
|
||||
app.permission_keys = permission_keys_map.get(str(app.id), [])
|
||||
else:
|
||||
for app in app_pagination.items:
|
||||
app.permission_keys = []
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
@ -574,25 +530,6 @@ class AppListApi(Resource):
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
workflow_ids = [str(app.workflow_id) for app in app_pagination.items if app.workflow_id]
|
||||
workflow_info_map: dict[str, tuple[str, str]] = {}
|
||||
if workflow_ids:
|
||||
rows = db.session.execute(
|
||||
select(Workflow.id, Workflow.type, Workflow.kind).where(Workflow.id.in_(workflow_ids))
|
||||
).all()
|
||||
workflow_info_map = {
|
||||
str(row.id): (
|
||||
row.type.value if hasattr(row.type, "value") else str(row.type),
|
||||
resolve_workflow_kind(row.kind).value,
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
|
||||
for app in app_pagination.items:
|
||||
workflow_info = workflow_info_map.get(str(app.workflow_id)) if app.workflow_id else None
|
||||
app.workflow_type = workflow_info[0] if workflow_info else None
|
||||
app.workflow_kind = workflow_info[1] if workflow_info else None
|
||||
|
||||
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
|
||||
return pagination_model.model_dump(mode="json"), 200
|
||||
|
||||
@ -639,7 +576,6 @@ class AppApi(Resource):
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
app_service = AppService()
|
||||
|
||||
app_model = app_service.get_app(app_model)
|
||||
@ -648,29 +584,6 @@ class AppApi(Resource):
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
|
||||
if app_model.workflow_id:
|
||||
row = db.session.execute(
|
||||
select(Workflow.type, Workflow.kind).where(Workflow.id == app_model.workflow_id)
|
||||
).first()
|
||||
app_model.workflow_type = (
|
||||
(row.type.value if hasattr(row.type, "value") else str(row.type)) if row else None
|
||||
)
|
||||
app_model.workflow_kind = resolve_workflow_kind(row.kind).value if row else None
|
||||
else:
|
||||
app_model.workflow_type = None
|
||||
app_model.workflow_kind = None
|
||||
|
||||
if dify_config.RBAC_ENABLED:
|
||||
app_access_matrix = enterprise_rbac_service.RBACService.AppAccess.matrix(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
str(app_model.id),
|
||||
)
|
||||
app_model.permission_keys = _collect_app_access_permission_keys(app_access_matrix)
|
||||
else:
|
||||
app_model.permission_keys = []
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
|
||||
@ -813,6 +726,7 @@ class AppExportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RedirectUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -938,10 +852,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
"""Get app trace"""
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(str(app_id), session)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@ -949,18 +864,23 @@ class AppTraceApi(Resource):
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Trace configuration updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=str(app_id),
|
||||
app_id=app_model.id,
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
)
|
||||
|
||||
@ -7,7 +7,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -66,6 +67,7 @@ class ChatMessagePayload(BaseMessagePayload):
|
||||
|
||||
|
||||
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
@ -124,7 +126,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_completion_message")
|
||||
@console_ns.doc(description="Stop a running completion message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -205,7 +207,7 @@ class ChatMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_chat_message")
|
||||
@console_ns.doc(description="Stop a running chat message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -16,6 +16,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from extensions.ext_database import db
|
||||
from fields._value_type_serializer import serialize_value_type
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
@ -25,12 +26,6 @@ class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class ConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -65,7 +60,7 @@ class ConversationVariableResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class PaginatedConversationVariableResponse(ResponseModel):
|
||||
|
||||
@ -13,17 +13,12 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
@ -41,14 +36,14 @@ class AppMCPServerResponse(ResponseModel):
|
||||
name: str
|
||||
server_code: str
|
||||
description: str
|
||||
status: str
|
||||
status: AppMCPServerStatus
|
||||
parameters: dict[str, Any] | list[Any] | str
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def _parse_json_string(cls, value: Any) -> Any:
|
||||
def _normalize_parameters(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
@ -59,7 +54,7 @@ class AppMCPServerResponse(ResponseModel):
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
|
||||
@ -70,7 +65,9 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "MCP server configuration retrieved successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@ -85,7 +82,9 @@ class AppMCPServerController(Resource):
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
|
||||
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
201, "MCP server configuration created successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@ -111,13 +110,15 @@ class AppMCPServerController(Resource):
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
|
||||
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(
|
||||
200, "MCP server configuration updated successfully", console_ns.models[AppMCPServerResponse.__name__]
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@ -154,7 +155,7 @@ class AppMCPServerRefreshController(Resource):
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(200, "MCP server refreshed successfully", console_ns.models[AppMCPServerResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@setup_required
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -37,10 +38,9 @@ from fields.conversation_fields import (
|
||||
JSONValue,
|
||||
MessageFile,
|
||||
format_files_contained,
|
||||
to_timestamp,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
@ -144,9 +144,7 @@ class MessageDetailResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return to_timestamp(value)
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class MessageInfiniteScrollPaginationResponse(ResponseModel):
|
||||
@ -165,6 +163,7 @@ register_schema_models(
|
||||
MessageDetailResponse,
|
||||
MessageInfiniteScrollPaginationResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
@ -250,7 +249,7 @@ class MessageFeedbackApi(Resource):
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(200, "Feedback updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
@ -9,8 +8,10 @@ from werkzeug.exceptions import BadRequest
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.ops_service import OpsService
|
||||
|
||||
|
||||
@ -43,11 +44,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id: UUID):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@get_app_model
|
||||
def get(self, app_model: App):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
trace_config = OpsService.get_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
trace_config = OpsService.get_tracing_app_config(
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider
|
||||
)
|
||||
if not trace_config:
|
||||
return {"has_not_configured": True}
|
||||
return trace_config
|
||||
@ -65,13 +69,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
"""Create a new trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@ -90,13 +95,14 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def patch(self, app_model: App):
|
||||
"""Update an existing trace app configuration"""
|
||||
args = TraceConfigPayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=str(app_id), tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_model.id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@ -113,12 +119,13 @@ class TraceAppConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id: UUID):
|
||||
@get_app_model
|
||||
def delete(self, app_model: App):
|
||||
"""Delete an existing trace app configuration"""
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=str(app_id), tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import AliasChoices, BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.fields import NewAppResponse, SimpleResultResponse
|
||||
from controllers.common.schema import (
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0,
|
||||
register_response_schema_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
@ -26,6 +28,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.file_access import DatabaseFileAccessController
|
||||
from core.helper import encrypter
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.plugin.impl.exc import PluginInvokeError
|
||||
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
|
||||
@ -38,22 +41,22 @@ from core.trigger.debug.event_selectors import (
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import SimpleAccount
|
||||
from fields.workflow_run_fields import WorkflowRunNodeExecutionResponse
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.file import helpers as file_helpers
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.variables import SecretVariable, SegmentType, VariableBase
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow, WorkflowKind
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
@ -68,42 +71,15 @@ LISTENING_RETRY_IN = 2000
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
from fields.workflow_fields import pipeline_variable_fields, serialize_value_type
|
||||
|
||||
conversation_variable_model = console_ns.model(
|
||||
"ConversationVariable",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"value_type": fields.String(attribute=serialize_value_type),
|
||||
"value": fields.Raw,
|
||||
"description": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
pipeline_variable_model = console_ns.model("PipelineVariable", pipeline_variable_fields)
|
||||
|
||||
# Workflow model with nested dependencies
|
||||
workflow_fields_copy = workflow_fields.copy()
|
||||
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
|
||||
workflow_fields_copy["updated_by"] = fields.Nested(
|
||||
simple_account_model, attribute="updated_by_account", allow_null=True
|
||||
)
|
||||
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
|
||||
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
|
||||
workflow_model = console_ns.model("Workflow", workflow_fields_copy)
|
||||
|
||||
# Workflow pagination model
|
||||
workflow_pagination_fields_copy = workflow_pagination_fields.copy()
|
||||
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
|
||||
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
|
||||
class EnvironmentVariableResponseDict(TypedDict):
|
||||
value_type: str
|
||||
id: NotRequired[str]
|
||||
name: NotRequired[str]
|
||||
value: NotRequired[Any]
|
||||
description: NotRequired[str | None]
|
||||
|
||||
|
||||
class SyncDraftWorkflowPayload(BaseModel):
|
||||
@ -161,23 +137,6 @@ class ConvertToWorkflowPayload(BaseModel):
|
||||
icon_background: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
keyword: str | None = Field(default=None, max_length=255)
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowTypeConvertQuery(BaseModel):
|
||||
target_type: Literal["workflow", "evaluation"]
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
@ -191,6 +150,110 @@ class WorkflowOnlineUsersPayload(BaseModel):
|
||||
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
|
||||
|
||||
|
||||
class WorkflowConversationVariableResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: Any = Field(json_schema_extra={"type": "object"})
|
||||
description: str
|
||||
|
||||
@field_validator("value_type", mode="before")
|
||||
@classmethod
|
||||
def _serialize_value_type(cls, value: Any) -> str:
|
||||
if hasattr(value, "exposed_type"):
|
||||
return str(value.exposed_type())
|
||||
return str(value)
|
||||
|
||||
|
||||
class PipelineVariableResponse(ResponseModel):
|
||||
label: str
|
||||
variable: str
|
||||
type: str
|
||||
belong_to_node_id: str
|
||||
max_length: int | None = None
|
||||
required: bool
|
||||
unit: str | None = None
|
||||
default_value: Any = Field(default=None, json_schema_extra={"type": "object"})
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
tooltips: str | None = None
|
||||
allowed_file_types: list[str] | None = None
|
||||
allowed_file_extensions: list[str] | None = Field(
|
||||
default=None, validation_alias=AliasChoices("allowed_file_extensions", "allow_file_extension")
|
||||
)
|
||||
allowed_file_upload_methods: list[str] | None = Field(
|
||||
default=None, validation_alias=AliasChoices("allowed_file_upload_methods", "allow_file_upload_methods")
|
||||
)
|
||||
|
||||
|
||||
class WorkflowEnvironmentVariableResponse(ResponseModel):
|
||||
value_type: str
|
||||
id: str
|
||||
name: str
|
||||
value: Any = Field(json_schema_extra={"type": "object"})
|
||||
description: str
|
||||
|
||||
|
||||
class WorkflowResponse(ResponseModel):
|
||||
id: str
|
||||
graph: dict[str, Any] = Field(validation_alias=AliasChoices("graph_dict", "graph"))
|
||||
features: dict[str, Any] = Field(validation_alias=AliasChoices("features_dict", "features"))
|
||||
hash: str = Field(validation_alias=AliasChoices("unique_hash", "hash"))
|
||||
version: str
|
||||
marked_name: str
|
||||
marked_comment: str
|
||||
created_by: SimpleAccount | None = Field(
|
||||
default=None, validation_alias=AliasChoices("created_by_account", "created_by")
|
||||
)
|
||||
created_at: int
|
||||
updated_by: SimpleAccount | None = Field(
|
||||
default=None, validation_alias=AliasChoices("updated_by_account", "updated_by")
|
||||
)
|
||||
updated_at: int
|
||||
tool_published: bool
|
||||
environment_variables: list[WorkflowEnvironmentVariableResponse]
|
||||
conversation_variables: list[WorkflowConversationVariableResponse]
|
||||
rag_pipeline_variables: list[PipelineVariableResponse]
|
||||
|
||||
@field_validator("created_at", "updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int:
|
||||
timestamp = to_timestamp(value)
|
||||
if timestamp is None:
|
||||
raise ValueError("timestamp is required")
|
||||
return timestamp
|
||||
|
||||
@field_validator("environment_variables", mode="before")
|
||||
@classmethod
|
||||
def _serialize_environment_variables(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
|
||||
return [_serialize_environment_variable(item) for item in value]
|
||||
|
||||
|
||||
class WorkflowPaginationResponse(ResponseModel):
|
||||
items: list[WorkflowResponse]
|
||||
page: int
|
||||
limit: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class WorkflowOnlineUser(ResponseModel):
|
||||
user_id: str
|
||||
username: str
|
||||
avatar: str | None = None
|
||||
|
||||
|
||||
class WorkflowOnlineUsersByApp(ResponseModel):
|
||||
app_id: str
|
||||
users: list[WorkflowOnlineUser]
|
||||
|
||||
|
||||
class WorkflowOnlineUsersResponse(ResponseModel):
|
||||
data: list[WorkflowOnlineUsersByApp]
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -199,27 +262,38 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
|
||||
node_ids: list[str]
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(SyncDraftWorkflowPayload)
|
||||
reg(AdvancedChatWorkflowRunPayload)
|
||||
reg(IterationNodeRunPayload)
|
||||
reg(LoopNodeRunPayload)
|
||||
reg(DraftWorkflowRunPayload)
|
||||
reg(DraftWorkflowNodeRunPayload)
|
||||
reg(PublishWorkflowPayload)
|
||||
reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowTypeConvertQuery)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SyncDraftWorkflowPayload,
|
||||
AdvancedChatWorkflowRunPayload,
|
||||
IterationNodeRunPayload,
|
||||
LoopNodeRunPayload,
|
||||
DraftWorkflowRunPayload,
|
||||
DraftWorkflowNodeRunPayload,
|
||||
PublishWorkflowPayload,
|
||||
DefaultBlockConfigQuery,
|
||||
ConvertToWorkflowPayload,
|
||||
WorkflowListQuery,
|
||||
WorkflowUpdatePayload,
|
||||
WorkflowFeaturesPayload,
|
||||
WorkflowOnlineUsersPayload,
|
||||
DraftWorkflowTriggerRunPayload,
|
||||
DraftWorkflowTriggerRunAllPayload,
|
||||
)
|
||||
register_response_schema_model(console_ns, WorkflowRunNodeExecutionResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowConversationVariableResponse,
|
||||
PipelineVariableResponse,
|
||||
WorkflowEnvironmentVariableResponse,
|
||||
WorkflowResponse,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowOnlineUser,
|
||||
WorkflowOnlineUsersByApp,
|
||||
WorkflowOnlineUsersResponse,
|
||||
NewAppResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
|
||||
@ -241,18 +315,56 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
|
||||
return file_objs
|
||||
|
||||
|
||||
def _serialize_environment_variable(value: Any) -> EnvironmentVariableResponseDict | Any:
|
||||
match value:
|
||||
case SecretVariable():
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": encrypter.full_mask_token(),
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
|
||||
case VariableBase():
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": value.value,
|
||||
"value_type": str(value.value_type.exposed_type()),
|
||||
"description": value.description,
|
||||
}
|
||||
|
||||
case dict():
|
||||
value_type_str = value.get("value_type")
|
||||
if not isinstance(value_type_str, str):
|
||||
raise TypeError(
|
||||
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
|
||||
)
|
||||
value_type = SegmentType(value_type_str).exposed_type()
|
||||
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
|
||||
raise ValueError(f"Unsupported environment variable value type: {value_type}")
|
||||
return value
|
||||
|
||||
case _:
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft")
|
||||
class DraftWorkflowApi(Resource):
|
||||
@console_ns.doc("get_draft_workflow")
|
||||
@console_ns.doc(description="Get draft workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow retrieved successfully",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -265,8 +377,8 @@ class DraftWorkflowApi(Resource):
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# return workflow, if not found, return None (initiate graph by frontend)
|
||||
return workflow
|
||||
# return workflow, if not found, return 404
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -760,7 +872,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@console_ns.doc("stop_workflow_task")
|
||||
@console_ns.doc(description="Stop running workflow task")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(404, "Task not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -840,13 +952,15 @@ class PublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("get_published_workflow")
|
||||
@console_ns.doc(description="Get published workflow for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
|
||||
@console_ns.response(404, "Published workflow not found")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflow retrieved successfully, or null if not found",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@ -857,7 +971,10 @@ class PublishedWorkflowApi(Resource):
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
if workflow is None:
|
||||
return None
|
||||
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@setup_required
|
||||
@ -898,54 +1015,6 @@ class PublishedWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/publish/evaluation")
|
||||
class EvaluationPublishedWorkflowApi(Resource):
|
||||
@console_ns.doc("publish_evaluation_workflow")
|
||||
@console_ns.doc(description="Publish draft workflow as evaluation workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__])
|
||||
@console_ns.response(200, "Evaluation workflow published successfully")
|
||||
@console_ns.response(400, "Invalid workflow or unsupported node type")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Publish draft workflow as evaluation workflow.
|
||||
|
||||
Evaluation workflows cannot include trigger or human-input nodes.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = workflow_service.publish_evaluation_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
# Keep workflow_id aligned with the latest published workflow.
|
||||
app_model_in_session = session.get(App, app_model.id)
|
||||
if app_model_in_session:
|
||||
app_model_in_session.workflow_id = workflow.id
|
||||
app_model_in_session.updated_by = current_user.id
|
||||
app_model_in_session.updated_at = naive_utc_now()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@console_ns.doc("get_default_block_configs")
|
||||
@ -1003,7 +1072,11 @@ class ConvertToWorkflowApi(Resource):
|
||||
@console_ns.doc("convert_to_workflow")
|
||||
@console_ns.doc(description="Convert application to workflow mode")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Application converted to workflow successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Application converted to workflow successfully",
|
||||
console_ns.models[NewAppResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Application cannot be converted")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -1040,7 +1113,11 @@ class WorkflowFeaturesApi(Resource):
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow features updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -1064,7 +1141,11 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.doc("get_all_published_workflows")
|
||||
@console_ns.doc(description="Get all published workflows for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflows retrieved successfully",
|
||||
console_ns.models[WorkflowPaginationResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -1096,14 +1177,14 @@ class PublishedAllWorkflowApi(Resource):
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
serialized_workflows = marshal(workflows, workflow_fields_copy)
|
||||
|
||||
return {
|
||||
"items": serialized_workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>/restore")
|
||||
@ -1143,66 +1224,19 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/convert-type")
|
||||
class WorkflowTypeConvertApi(Resource):
|
||||
@console_ns.doc("convert_published_workflow_type")
|
||||
@console_ns.doc(description="Convert current effective published workflow type in-place")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowTypeConvertQuery.__name__])
|
||||
@console_ns.response(200, "Workflow type converted successfully")
|
||||
@console_ns.response(400, "Invalid workflow type or unsupported workflow graph")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = WorkflowTypeConvertQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
target_type = WorkflowKind.EVALUATION if args.target_type == "evaluation" else WorkflowKind.STANDARD
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
workflow = workflow_service.convert_published_workflow_type(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
target_type=target_type,
|
||||
account=current_user,
|
||||
)
|
||||
except WorkflowNotFoundError as exc:
|
||||
raise NotFound(str(exc)) from exc
|
||||
except IsDraftWorkflowError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc)) from exc
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"workflow_id": workflow.id,
|
||||
"type": workflow.type.value,
|
||||
"kind": workflow.kind_or_standard,
|
||||
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@console_ns.doc("update_workflow_by_id")
|
||||
@console_ns.doc(description="Update workflow by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Workflow updated successfully", workflow_model)
|
||||
@console_ns.response(200, "Workflow updated successfully", console_ns.models[WorkflowResponse.__name__])
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@ -1236,7 +1270,7 @@ class WorkflowByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
return workflow
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -1521,12 +1555,16 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Workflow online users retrieved successfully",
|
||||
console_ns.models[WorkflowOnlineUsersResponse.__name__],
|
||||
)
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def post(self):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -1569,10 +1607,18 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
if not isinstance(user_info, dict):
|
||||
continue
|
||||
|
||||
user_id = user_info.get("user_id")
|
||||
username = user_info.get("username")
|
||||
if not isinstance(user_id, str) or not isinstance(username, str):
|
||||
continue
|
||||
|
||||
avatar = user_info.get("avatar")
|
||||
if avatar is not None and not isinstance(avatar, str):
|
||||
avatar = None
|
||||
|
||||
if isinstance(avatar, str) and avatar and not avatar.startswith(("http://", "https://")):
|
||||
try:
|
||||
user_info["avatar"] = file_helpers.get_signed_file_url(avatar)
|
||||
avatar = file_helpers.get_signed_file_url(avatar)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to sign workflow online user avatar; using original value. "
|
||||
@ -1582,7 +1628,7 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
exc,
|
||||
)
|
||||
|
||||
users.append(user_info)
|
||||
users.append({"user_id": user_id, "username": username, "avatar": avatar})
|
||||
results.append({"app_id": app_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
return WorkflowOnlineUsersResponse.model_validate({"data": results}).model_dump(mode="json")
|
||||
|
||||
@ -16,6 +16,7 @@ from fields.base import ResponseModel
|
||||
from fields.end_user_fields import SimpleEndUser
|
||||
from fields.member_fields import SimpleAccount
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
@ -82,9 +83,7 @@ class WorkflowRunForLogResponse(ResponseModel):
|
||||
@field_validator("created_at", "finished_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
@ -104,28 +103,10 @@ class WorkflowRunForArchivedLogResponse(ResponseModel):
|
||||
return str(getattr(value, "value", value))
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationNodeInfoResponse(ResponseModel):
|
||||
node_id: str
|
||||
type: str
|
||||
title: str
|
||||
|
||||
|
||||
class WorkflowAppLogEvaluationItemResponse(ResponseModel):
|
||||
name: str
|
||||
value: Any = None
|
||||
details: dict[str, Any] | None = None
|
||||
node_info: WorkflowAppLogEvaluationNodeInfoResponse | None = Field(
|
||||
default=None,
|
||||
validation_alias="node_info",
|
||||
serialization_alias="nodeInfo",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
id: str
|
||||
workflow_run: WorkflowRunForLogResponse | None = None
|
||||
details: Any = None
|
||||
evaluation: list[WorkflowAppLogEvaluationItemResponse] = Field(default_factory=list)
|
||||
created_from: str | None = None
|
||||
created_by_role: str | None = None
|
||||
created_by_account: SimpleAccount | None = None
|
||||
@ -135,14 +116,7 @@ class WorkflowAppLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
@field_validator("evaluation", mode="before")
|
||||
@classmethod
|
||||
def _normalize_evaluation(cls, value: Any) -> list[dict[str, Any]] | list[WorkflowAppLogEvaluationItemResponse]:
|
||||
return value or []
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@ -156,9 +130,7 @@ class WorkflowArchivedLogPartialResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowAppLogPaginationResponse(ResponseModel):
|
||||
@ -182,8 +154,6 @@ register_schema_models(
|
||||
WorkflowAppLogQuery,
|
||||
WorkflowRunForLogResponse,
|
||||
WorkflowRunForArchivedLogResponse,
|
||||
WorkflowAppLogEvaluationNodeInfoResponse,
|
||||
WorkflowAppLogEvaluationItemResponse,
|
||||
WorkflowAppLogPartialResponse,
|
||||
WorkflowArchivedLogPartialResponse,
|
||||
WorkflowAppLogPaginationResponse,
|
||||
|
||||
@ -1,22 +1,16 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, computed_field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.helper import build_avatar_url, dump_response, to_timestamp
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
@ -51,6 +45,138 @@ class WorkflowCommentMentionUsersPayload(BaseModel):
|
||||
users: list[AccountWithRole]
|
||||
|
||||
|
||||
class WorkflowCommentAccount(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
avatar: str | None = Field(default=None, exclude=True)
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
|
||||
@property
|
||||
def avatar_url(self) -> str | None:
|
||||
return build_avatar_url(self.avatar)
|
||||
|
||||
|
||||
class WorkflowCommentReply(ResponseModel):
|
||||
id: str
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentMention(ResponseModel):
|
||||
mentioned_user_id: str
|
||||
mentioned_user_account: WorkflowCommentAccount | None = None
|
||||
reply_id: str | None = None
|
||||
|
||||
|
||||
class WorkflowCommentBasic(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
reply_count: int
|
||||
mention_count: int
|
||||
participants: list[WorkflowCommentAccount]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentBasicList(ResponseModel):
|
||||
data: list[WorkflowCommentBasic]
|
||||
|
||||
|
||||
class WorkflowCommentDetail(ResponseModel):
|
||||
id: str
|
||||
position_x: float
|
||||
position_y: float
|
||||
content: str
|
||||
created_by: str
|
||||
created_by_account: WorkflowCommentAccount | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
resolved_by_account: WorkflowCommentAccount | None = None
|
||||
replies: list[WorkflowCommentReply]
|
||||
mentions: list[WorkflowCommentMention]
|
||||
|
||||
@field_validator("created_at", "updated_at", "resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentResolve(ResponseModel):
|
||||
id: str
|
||||
resolved: bool
|
||||
resolved_at: int | None = None
|
||||
resolved_by: str | None = None
|
||||
|
||||
@field_validator("resolved_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreate(ResponseModel):
|
||||
id: str
|
||||
created_at: int | None = None
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdate(ResponseModel):
|
||||
id: str
|
||||
updated_at: int | None = None
|
||||
|
||||
@field_validator("updated_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountWithRole,
|
||||
@ -59,17 +185,19 @@ register_schema_models(
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyPayload,
|
||||
)
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
WorkflowCommentAccount,
|
||||
WorkflowCommentReply,
|
||||
WorkflowCommentMention,
|
||||
WorkflowCommentBasic,
|
||||
WorkflowCommentBasicList,
|
||||
WorkflowCommentDetail,
|
||||
WorkflowCommentCreate,
|
||||
WorkflowCommentUpdate,
|
||||
WorkflowCommentResolve,
|
||||
WorkflowCommentReplyCreate,
|
||||
WorkflowCommentReplyUpdate,
|
||||
)
|
||||
|
||||
|
||||
@ -80,28 +208,26 @@ class WorkflowCommentListApi(Resource):
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@console_ns.response(200, "Comments retrieved successfully", console_ns.models[WorkflowCommentBasicList.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
return WorkflowCommentBasicList.model_validate({"data": comments}).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@console_ns.response(201, "Comment created successfully", console_ns.models[WorkflowCommentCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
@ -117,7 +243,7 @@ class WorkflowCommentListApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
@ -127,30 +253,28 @@ class WorkflowCommentDetailApi(Resource):
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@console_ns.response(200, "Comment retrieved successfully", console_ns.models[WorkflowCommentDetail.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentDetail, comment)
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@console_ns.response(200, "Comment updated successfully", console_ns.models[WorkflowCommentUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
@ -167,7 +291,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
return dump_response(WorkflowCommentUpdate, result)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@ -197,12 +321,11 @@ class WorkflowCommentResolveApi(Resource):
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@console_ns.response(200, "Comment resolved successfully", console_ns.models[WorkflowCommentResolve.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
@ -213,7 +336,7 @@ class WorkflowCommentResolveApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
return dump_response(WorkflowCommentResolve, comment)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
@ -224,12 +347,11 @@ class WorkflowCommentReplyApi(Resource):
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@console_ns.response(201, "Reply created successfully", console_ns.models[WorkflowCommentReplyCreate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
@ -247,7 +369,7 @@ class WorkflowCommentReplyApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
return dump_response(WorkflowCommentReplyCreate, result), 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
@ -258,12 +380,11 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyPayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@console_ns.response(200, "Reply updated successfully", console_ns.models[WorkflowCommentReplyUpdate.__name__])
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
@ -284,7 +405,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
return dump_response(WorkflowCommentReplyUpdate, reply)
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -40,16 +38,29 @@ class ActivatePayload(BaseModel):
|
||||
return timezone(value)
|
||||
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: dict[str, Any] | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
class ActivationResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
|
||||
|
||||
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
|
||||
class ActivationCheckData(BaseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
email: str | None
|
||||
|
||||
|
||||
class ActivationCheckResponse(BaseModel):
|
||||
is_valid: bool = Field(description="Whether token is valid")
|
||||
data: ActivationCheckData | None = Field(default=None, description="Activation data if valid")
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ActivateCheckQuery,
|
||||
ActivatePayload,
|
||||
ActivationCheckData,
|
||||
ActivationCheckResponse,
|
||||
ActivationResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
@ -16,11 +17,26 @@ class ApiKeyAuthBindingPayload(BaseModel):
|
||||
credentials: dict = Field(...)
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceItem(ResponseModel):
|
||||
id: str
|
||||
category: str
|
||||
provider: str
|
||||
disabled: bool
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceListResponse(ResponseModel):
|
||||
sources: list[ApiKeyAuthDataSourceItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyAuthBindingPayload)
|
||||
register_response_schema_models(console_ns, ApiKeyAuthDataSourceItem, ApiKeyAuthDataSourceListResponse)
|
||||
|
||||
|
||||
@console_ns.route("/api-key-auth/data-source")
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ApiKeyAuthDataSourceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -70,6 +86,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -4,7 +4,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language, languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultDataResponse, VerificationTokenResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -58,6 +59,7 @@ class EmailRegisterResetPayload(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultDataResponse, VerificationTokenResponse)
|
||||
|
||||
|
||||
@console_ns.route("/email-register/send-email")
|
||||
@ -65,6 +67,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -89,6 +92,7 @@ class EmailRegisterCheckApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[VerificationTokenResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
|
||||
|
||||
|
||||
@ -9,7 +9,8 @@ from werkzeug.exceptions import Unauthorized
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultDataResponse, SimpleResultOptionalDataResponse, SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
@ -81,6 +82,12 @@ class EmailCodeLoginPayload(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultOptionalDataResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/login")
|
||||
@ -90,6 +97,7 @@ class LoginApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[LoginPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultOptionalDataResponse.__name__])
|
||||
@decrypt_password_field
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
@ -163,6 +171,7 @@ class LoginApi(Resource):
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@ -186,6 +195,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -213,6 +223,7 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def post(self):
|
||||
args = EmailPayload.model_validate(console_ns.payload)
|
||||
normalized_email = args.email.lower()
|
||||
@ -245,6 +256,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.expect(console_ns.models[EmailCodeLoginPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@decrypt_code_field
|
||||
def post(self):
|
||||
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
|
||||
@ -321,6 +333,7 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
@ -12,7 +10,6 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
@ -80,39 +77,3 @@ class PartnerTenants(Resource):
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
|
||||
|
||||
_DEBUG_KEY = "billing:debug"
|
||||
_DEBUG_TTL = timedelta(days=7)
|
||||
|
||||
|
||||
class DebugDataPayload(BaseModel):
|
||||
type: str = Field(..., min_length=1, description="Data type key")
|
||||
data: str = Field(..., min_length=1, description="Data value to append")
|
||||
|
||||
|
||||
@console_ns.route("/billing/debug/data")
|
||||
class DebugData(Resource):
|
||||
def post(self):
|
||||
body = DebugDataPayload.model_validate(request.get_json(force=True))
|
||||
item = json.dumps({
|
||||
"type": body.type,
|
||||
"data": body.data,
|
||||
"createTime": datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
})
|
||||
redis_client.lpush(_DEBUG_KEY, item)
|
||||
redis_client.expire(_DEBUG_KEY, _DEBUG_TTL)
|
||||
return {"result": "ok"}, 201
|
||||
|
||||
def get(self):
|
||||
recent = request.args.get("recent", 10, type=int)
|
||||
items = redis_client.lrange(_DEBUG_KEY, 0, recent - 1)
|
||||
return {
|
||||
"data": [
|
||||
json.loads(item.decode("utf-8") if isinstance(item, bytes) else item) for item in items
|
||||
]
|
||||
}
|
||||
|
||||
def delete(self):
|
||||
redis_client.delete(_DEBUG_KEY)
|
||||
return {"result": "ok"}
|
||||
|
||||
@ -9,7 +9,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_model
|
||||
from controllers.common.fields import SimpleResultResponse, TextContentResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
@ -54,6 +55,7 @@ class DataSourceNotionPreviewQuery(BaseModel):
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
@ -157,6 +159,7 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
@ -289,6 +292,7 @@ class DataSourceNotionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -362,6 +366,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -379,6 +384,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -24,7 +22,6 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
@ -33,7 +30,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
@ -56,21 +52,14 @@ from fields.document_fields import document_status_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermission, DatasetPermissionEnum
|
||||
from models.enums import ApiTokenType, SegmentStatus
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
@ -141,14 +130,6 @@ def _validate_doc_form(value: str | None) -> str | None:
|
||||
return value
|
||||
|
||||
|
||||
def _ensure_permission_keys(dataset: Dataset, *, enabled: bool) -> None:
|
||||
if not enabled:
|
||||
setattr(dataset, "permission_keys", [])
|
||||
return
|
||||
if not isinstance(getattr(dataset, "permission_keys", None), list):
|
||||
setattr(dataset, "permission_keys", [])
|
||||
|
||||
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field("", max_length=400)
|
||||
@ -351,19 +332,6 @@ class DatasetListApi(Resource):
|
||||
query.include_all,
|
||||
)
|
||||
|
||||
for dataset in datasets:
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
|
||||
if dify_config.RBAC_ENABLED and datasets:
|
||||
dataset_ids = [str(dataset.id) for dataset in datasets]
|
||||
permission_keys_map = enterprise_rbac_service.RBACService.DatasetPermissions.batch_get(
|
||||
str(current_tenant_id),
|
||||
current_user.id,
|
||||
dataset_ids,
|
||||
)
|
||||
for dataset in datasets:
|
||||
setattr(dataset, "permission_keys", permission_keys_map.get(str(dataset.id), []))
|
||||
|
||||
# check embedding setting
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id)
|
||||
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
|
||||
@ -445,7 +413,6 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
@ -470,7 +437,6 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
@ -540,7 +506,6 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
_ensure_permission_keys(dataset, enabled=dify_config.RBAC_ENABLED)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@ -559,6 +524,7 @@ class DatasetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -581,7 +547,11 @@ class DatasetUseCheckApi(Resource):
|
||||
@console_ns.doc("check_dataset_use")
|
||||
@console_ns.doc(description="Check if dataset is in use")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset use status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset use status retrieved successfully",
|
||||
console_ns.models[UsageCheckResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -911,6 +881,7 @@ class DatasetEnableApiApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
@ -923,7 +894,7 @@ class DatasetEnableApiApi(Resource):
|
||||
class DatasetApiBaseUrlApi(Resource):
|
||||
@console_ns.doc("get_dataset_api_base_info")
|
||||
@console_ns.doc(description="Get dataset API base information")
|
||||
@console_ns.response(200, "API base info retrieved successfully")
|
||||
@console_ns.response(200, "API base info retrieved successfully", console_ns.models[ApiBaseUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -1023,432 +994,3 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
|
||||
|
||||
# ---- Knowledge Base Retrieval Evaluation ----
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run(run: EvaluationRun) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": json.loads(run.metrics_summary) if run.metrics_summary else {},
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_dataset_evaluation_run_item(item: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/template/download")
|
||||
class DatasetEvaluationTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Download evaluation dataset template for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
xlsx_content, filename = EvaluationService.generate_retrieval_dataset_template()
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation")
|
||||
class DatasetEvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(
|
||||
session, current_tenant_id, "dataset", dataset_id_str
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
@console_ns.doc("save_dataset_evaluation_config")
|
||||
@console_ns.response(200, "Evaluation configuration saved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id):
|
||||
"""Save evaluation configuration for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": config.default_metrics_list,
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": config.judgment_config_dict,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/run")
|
||||
class DatasetEvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
"""Start an evaluation run for the knowledge base retrieval."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = (
|
||||
db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=current_tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=EvaluationTargetType.KNOWLEDGE_BASE,
|
||||
target_id=dataset_id_str,
|
||||
account_id=str(current_user.id),
|
||||
dataset_file_content=dataset_content,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/logs")
|
||||
class DatasetEvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get evaluation run history for the knowledge base."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type="dataset",
|
||||
target_id=dataset_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_dataset_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>")
|
||||
class DatasetEvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, run_id):
|
||||
"""Get evaluation run detail including per-item results."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id_str,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return {
|
||||
"run": _serialize_dataset_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_dataset_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class DatasetEvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_dataset_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id, run_id):
|
||||
"""Cancel a running knowledge base evaluation."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
run_id_str = str(run_id)
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id_str,
|
||||
)
|
||||
return _serialize_dataset_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/metrics")
|
||||
class DatasetEvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_dataset_evaluation_metrics")
|
||||
@console_ns.response(200, "Available retrieval metrics retrieved")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id):
|
||||
"""Get available evaluation metrics for knowledge base retrieval."""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return {
|
||||
"metrics": EvaluationService.get_supported_metrics(EvaluationCategory.RETRIEVAL)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/evaluation/files/<uuid:file_id>")
|
||||
class DatasetEvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_dataset_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, file_id):
|
||||
"""Download evaluation test file or result file for the knowledge base."""
|
||||
from core.workflow.file import helpers as file_helpers
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
file_id_str = str(file_id)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id_str,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found.")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
@ -3,18 +3,20 @@ import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import SimpleResultMessageResponse, SimpleResultResponse, UrlResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
@ -29,17 +31,16 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_metadata_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
@ -72,27 +73,94 @@ from ..wraps import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = get_or_create_model("Dataset", dataset_fields)
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
document_metadata_model = get_or_create_model("DocumentMetadata", document_metadata_fields)
|
||||
|
||||
document_fields_copy = document_fields.copy()
|
||||
document_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_model = get_or_create_model("Document", document_fields_copy)
|
||||
class DatasetResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
permission: str | None = None
|
||||
data_source_type: str | None = None
|
||||
indexing_technique: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
document_with_segments_fields_copy = document_with_segments_fields.copy()
|
||||
document_with_segments_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_with_segments_model = get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
|
||||
@field_validator("data_source_type", "indexing_technique", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
|
||||
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
|
||||
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
|
||||
dataset_and_document_model = get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class DocumentResponse(ResponseModel):
|
||||
id: str
|
||||
position: int | None = None
|
||||
data_source_type: str | None = None
|
||||
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
|
||||
data_source_detail_dict: Any = None
|
||||
dataset_process_rule_id: str | None = None
|
||||
name: str
|
||||
created_from: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
tokens: int | None = None
|
||||
indexing_status: str | None = None
|
||||
error: str | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
archived: bool | None = None
|
||||
display_status: str | None = None
|
||||
word_count: int | None = None
|
||||
hit_count: int | None = None
|
||||
doc_form: str | None = None
|
||||
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
|
||||
summary_index_status: str | None = None
|
||||
need_summary: bool | None = None
|
||||
|
||||
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
@field_validator("doc_metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
process_rule_dict: Any = None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
|
||||
|
||||
class DatasetAndDocumentResponse(ResponseModel):
|
||||
dataset: DatasetResponse
|
||||
documents: list[DocumentResponse]
|
||||
batch: str
|
||||
|
||||
|
||||
class DocumentRetryPayload(BaseModel):
|
||||
@ -107,6 +175,11 @@ class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentMetadataUpdatePayload(BaseModel):
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any = None
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
page: int = Field(1, title="Page", description="Page number.")
|
||||
limit: int = Field(20, title="Limit", description="Page size.")
|
||||
@ -124,8 +197,15 @@ register_schema_models(
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
GenerateSummaryPayload,
|
||||
DocumentMetadataUpdatePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
DatasetResponse,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
@ -360,10 +440,10 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
@ -401,12 +481,15 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Documents deleted successfully")
|
||||
def delete(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -429,12 +512,13 @@ class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(
|
||||
201, "Dataset initialized successfully", console_ns.models[DatasetAndDocumentResponse.__name__]
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@ -482,9 +566,9 @@ class DatasetInitApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
response = {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
|
||||
return response
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
@ -865,6 +949,7 @@ class DocumentApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
@ -890,6 +975,7 @@ class DocumentDownloadApi(DocumentResource):
|
||||
|
||||
@console_ns.doc("get_dataset_document_download_url")
|
||||
@console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
|
||||
@console_ns.response(200, "Download URL generated successfully", console_ns.models[UrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -947,7 +1033,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@console_ns.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
|
||||
)
|
||||
@console_ns.response(200, "Processing status updated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Processing status updated successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Invalid action")
|
||||
@setup_required
|
||||
@ -991,16 +1081,12 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@console_ns.doc("update_document_metadata")
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"UpdateDocumentMetadataRequest",
|
||||
{
|
||||
"doc_type": fields.String(description="Document type"),
|
||||
"doc_metadata": fields.Raw(description="Document metadata"),
|
||||
},
|
||||
)
|
||||
@console_ns.expect(console_ns.models[DocumentMetadataUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Document metadata updated successfully",
|
||||
console_ns.models[SimpleResultMessageResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -1012,10 +1098,10 @@ class DocumentMetadataApi(DocumentResource):
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
req_data = request.get_json()
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
doc_type = req_data.get("doc_type")
|
||||
doc_metadata = req_data.get("doc_metadata")
|
||||
doc_type = req_data.doc_type
|
||||
doc_metadata = req_data.doc_metadata
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1054,6 +1140,7 @@ class DocumentStatusApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1091,6 +1178,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document paused successfully")
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""pause document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1125,6 +1213,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document resumed successfully")
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""recover document."""
|
||||
dataset_id = str(dataset_id)
|
||||
@ -1157,6 +1246,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
@console_ns.response(204, "Documents retry started successfully")
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
@ -1197,7 +1287,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(document_model)
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
@ -1215,7 +1305,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return document
|
||||
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
@ -1223,6 +1313,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -1289,7 +1380,11 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Summary generation started successfully",
|
||||
console_ns.models[SimpleResultResponse.__name__],
|
||||
)
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
|
||||
@ -10,7 +10,8 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@ -30,6 +31,7 @@ from core.model_manager import ModelManager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import escape_like_pattern
|
||||
@ -83,6 +85,11 @@ class BatchImportPayload(BaseModel):
|
||||
upload_file_id: str
|
||||
|
||||
|
||||
class SegmentBatchImportStatusResponse(ResponseModel):
|
||||
job_id: str
|
||||
job_status: str
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
@ -98,6 +105,7 @@ register_schema_models(
|
||||
ChildChunkBatchUpdatePayload,
|
||||
ChildChunkUpdateArgs,
|
||||
)
|
||||
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
@ -217,6 +225,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@ -252,6 +261,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -424,6 +434,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -464,6 +475,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
"/datasets/batch_import_status/<uuid:job_id>",
|
||||
)
|
||||
class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@console_ns.response(200, "Batch import started", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -514,6 +526,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
|
||||
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -691,6 +704,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import UsageCountResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@ -27,6 +28,8 @@ from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService
|
||||
|
||||
register_response_schema_models(console_ns, UsageCountResponse)
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
@ -206,6 +209,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
@ -222,7 +226,7 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@console_ns.doc("check_external_api_usage")
|
||||
@console_ns.doc(description="Check if external knowledge API is being used")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "Usage check completed successfully")
|
||||
@console_ns.response(200, "Usage check completed successfully", console_ns.models[UsageCountResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -19,12 +20,6 @@ from ..wraps import (
|
||||
)
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
@ -61,7 +56,7 @@ class HitTestingSegment(ResponseModel):
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
|
||||
@ -39,11 +39,8 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
@ -52,15 +49,15 @@ class DatasetsHitTestingBase:
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
raise ValueError("Invalid hit testing records response")
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
raise ValueError("Invalid hit testing record response")
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
@ -118,8 +115,8 @@ class DatasetsHitTestingBase:
|
||||
limit=10,
|
||||
)
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
|
||||
@ -4,7 +4,8 @@ from flask_restx import Resource, marshal_with
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
@ -21,6 +22,7 @@ from services.metadata_service import MetadataService
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -83,6 +85,7 @@ class DatasetMetadataApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -113,6 +116,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -136,6 +140,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
@ -6,7 +6,8 @@ from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
@ -56,6 +57,7 @@ register_schema_models(
|
||||
DatasourceDefaultPayload,
|
||||
DatasourceUpdateNamePayload,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
@ -209,6 +211,7 @@ class DatasourceAuth(Resource):
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -306,6 +309,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@ -321,6 +325,7 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -342,6 +347,7 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -6,7 +6,8 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -59,6 +60,7 @@ class Payload(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
register_response_schema_models(console_ns, SimpleDataResponse)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
@ -85,6 +87,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
|
||||
def post(self, template_id: str):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = session.scalar(
|
||||
|
||||
@ -3,13 +3,14 @@ import logging
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@ -19,8 +20,8 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.app.workflow import (
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
workflow_model,
|
||||
workflow_pagination_model,
|
||||
WorkflowPaginationResponse,
|
||||
WorkflowResponse,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
@ -34,6 +35,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from fields.base import ResponseModel
|
||||
from fields.workflow_run_fields import (
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
@ -42,7 +44,7 @@ from fields.workflow_run_fields import (
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
@ -115,6 +117,17 @@ class RagPipelineRecommendedPluginQuery(BaseModel):
|
||||
type: str = "all"
|
||||
|
||||
|
||||
class RagPipelineWorkflowSyncResponse(ResponseModel):
|
||||
result: str
|
||||
hash: str
|
||||
updated_at: int
|
||||
|
||||
|
||||
class RagPipelineWorkflowPublishResponse(ResponseModel):
|
||||
result: str
|
||||
created_at: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
@ -133,6 +146,9 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
RagPipelineWorkflowPublishResponse,
|
||||
RagPipelineWorkflowSyncResponse,
|
||||
SimpleResultResponse,
|
||||
WorkflowRunDetailResponse,
|
||||
WorkflowRunNodeExecutionListResponse,
|
||||
WorkflowRunNodeExecutionResponse,
|
||||
@ -142,12 +158,17 @@ register_response_schema_models(
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Draft workflow retrieved successfully",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft rag pipeline's workflow
|
||||
@ -159,14 +180,15 @@ class DraftRagPipelineApi(Resource):
|
||||
if not workflow:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# return workflow, if not found, return None (initiate graph by frontend)
|
||||
return workflow
|
||||
# return workflow, if not found, return 404
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
@edit_permission_required
|
||||
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__])
|
||||
def post(self, pipeline: Pipeline):
|
||||
"""
|
||||
Sync draft workflow
|
||||
@ -457,6 +479,7 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
class RagPipelineTaskStopApi(Resource):
|
||||
@console_ns.response(200, "Task stopped successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@ -476,12 +499,16 @@ class RagPipelineTaskStopApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/publish")
|
||||
class PublishedRagPipelineApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflow retrieved successfully, or null if not exist",
|
||||
console_ns.models[WorkflowResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get published pipeline
|
||||
@ -494,8 +521,12 @@ class PublishedRagPipelineApi(Resource):
|
||||
workflow = rag_pipeline_service.get_published_workflow(pipeline=pipeline)
|
||||
|
||||
# return workflow, if not found, return None
|
||||
return workflow
|
||||
if workflow is None:
|
||||
return None
|
||||
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowPublishResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -567,12 +598,17 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Published workflows retrieved successfully",
|
||||
console_ns.models[WorkflowPaginationResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_pagination_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get published workflows
|
||||
@ -601,16 +637,19 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
named_only=named_only,
|
||||
)
|
||||
|
||||
return {
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
return WorkflowPaginationResponse.model_validate(
|
||||
{
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>/restore")
|
||||
class RagPipelineDraftWorkflowRestoreApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RagPipelineWorkflowSyncResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -641,12 +680,15 @@ class RagPipelineDraftWorkflowRestoreApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@console_ns.response(200, "Workflow updated successfully", console_ns.models[WorkflowResponse.__name__])
|
||||
@console_ns.response(400, "No valid fields to update")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Workflow not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_rag_pipeline
|
||||
@marshal_with(workflow_model)
|
||||
def patch(self, pipeline: Pipeline, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
@ -675,8 +717,9 @@ class RagPipelineByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
return workflow
|
||||
return dump_response(WorkflowResponse, workflow)
|
||||
|
||||
@console_ns.response(204, "Workflow deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# Evaluation controller module
|
||||
@ -1,993 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.workflow import WorkflowListQuery
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.member_fields import simple_account_fields
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, Dataset
|
||||
from models.evaluation import EvaluationTargetType
|
||||
from models.model import UploadFile
|
||||
from models.snippet import CustomizedSnippet
|
||||
from services.errors.evaluation import (
|
||||
EvaluationDatasetInvalidError,
|
||||
EvaluationFrameworkNotConfiguredError,
|
||||
EvaluationMaxConcurrentRunsError,
|
||||
EvaluationNotFoundError,
|
||||
)
|
||||
from services.evaluation_service import EvaluationService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.evaluation import EvaluationRun, EvaluationRunItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVALUATE_TARGET_TYPES = {
|
||||
EvaluationTargetType.APPS.value,
|
||||
EvaluationTargetType.SNIPPETS.value,
|
||||
}
|
||||
|
||||
|
||||
class VersionQuery(BaseModel):
|
||||
"""Query parameters for version endpoint."""
|
||||
|
||||
version: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
VersionQuery,
|
||||
)
|
||||
|
||||
|
||||
# Response field definitions
|
||||
file_info_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_fields = {
|
||||
"created_at": TimestampField,
|
||||
"created_by": fields.String,
|
||||
"test_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTestFile",
|
||||
file_info_fields,
|
||||
)
|
||||
),
|
||||
"result_file": fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationResultFile",
|
||||
file_info_fields,
|
||||
),
|
||||
allow_null=True,
|
||||
),
|
||||
"version": fields.String,
|
||||
}
|
||||
|
||||
evaluation_log_list_model = console_ns.model(
|
||||
"EvaluationLogList",
|
||||
{
|
||||
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
|
||||
},
|
||||
)
|
||||
|
||||
evaluation_default_metric_node_info_fields = {
|
||||
"node_id": fields.String,
|
||||
"type": fields.String,
|
||||
"title": fields.String,
|
||||
}
|
||||
evaluation_default_metric_item_fields = {
|
||||
"metric": fields.String,
|
||||
"value_type": fields.String,
|
||||
"node_info_list": fields.List(
|
||||
fields.Nested(
|
||||
console_ns.model("EvaluationDefaultMetricNodeInfo", evaluation_default_metric_node_info_fields),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
customized_metrics_fields = {
|
||||
"evaluation_workflow_id": fields.String,
|
||||
"input_fields": fields.Raw,
|
||||
"output_fields": fields.Raw,
|
||||
}
|
||||
|
||||
judgment_condition_fields = {
|
||||
"variable_selector": fields.List(fields.String),
|
||||
"comparison_operator": fields.String,
|
||||
"value": fields.String,
|
||||
}
|
||||
|
||||
judgment_config_fields = {
|
||||
"logical_operator": fields.String,
|
||||
"conditions": fields.List(fields.Nested(console_ns.model("JudgmentCondition", judgment_condition_fields))),
|
||||
}
|
||||
|
||||
evaluation_detail_fields = {
|
||||
"evaluation_model": fields.String,
|
||||
"evaluation_model_provider": fields.String,
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem_Detail", evaluation_default_metric_item_fields)),
|
||||
allow_null=True,
|
||||
),
|
||||
"customized_metrics": fields.Nested(
|
||||
console_ns.model("EvaluationCustomizedMetrics", customized_metrics_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
"judgment_config": fields.Nested(
|
||||
console_ns.model("EvaluationJudgmentConfig", judgment_config_fields),
|
||||
allow_null=True,
|
||||
),
|
||||
}
|
||||
|
||||
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
|
||||
|
||||
available_evaluation_workflow_list_fields = {
|
||||
"id": fields.String,
|
||||
"app_id": fields.String,
|
||||
"app_name": fields.String,
|
||||
"type": fields.String,
|
||||
"kind": fields.String,
|
||||
"version": fields.String,
|
||||
"marked_name": fields.String,
|
||||
"marked_comment": fields.String,
|
||||
"hash": fields.String,
|
||||
"created_by": fields.Nested(simple_account_fields),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, allow_null=True),
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_fields = {
|
||||
"items": fields.List(fields.Nested(available_evaluation_workflow_list_fields)),
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
}
|
||||
|
||||
available_evaluation_workflow_pagination_model = console_ns.model(
|
||||
"AvailableEvaluationWorkflowPagination",
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
)
|
||||
|
||||
evaluation_default_metrics_response_model = console_ns.model(
|
||||
"EvaluationDefaultMetricsResponse",
|
||||
{
|
||||
"default_metrics": fields.List(
|
||||
fields.Nested(console_ns.model("EvaluationDefaultMetricItem", evaluation_default_metric_item_fields)),
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
evaluation_dataset_columns_response_model = console_ns.model(
|
||||
"EvaluationDatasetColumnsResponse",
|
||||
{
|
||||
"columns": fields.List(
|
||||
fields.Nested(
|
||||
console_ns.model(
|
||||
"EvaluationTemplateColumn",
|
||||
{
|
||||
"name": fields.String,
|
||||
"type": fields.String,
|
||||
},
|
||||
)
|
||||
)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_evaluation_target[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to resolve polymorphic evaluation target (apps or snippets).
|
||||
|
||||
Validates the target_type parameter and fetches the corresponding
|
||||
model (App or CustomizedSnippet) with tenant isolation.
|
||||
"""
|
||||
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
target_type = kwargs.get("evaluate_target_type")
|
||||
target_id = kwargs.get("evaluate_target_id")
|
||||
|
||||
if target_type not in EVALUATE_TARGET_TYPES:
|
||||
raise NotFound(f"Invalid evaluation target type: {target_type}")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
target_id = str(target_id)
|
||||
|
||||
# Remove path parameters
|
||||
del kwargs["evaluate_target_type"]
|
||||
del kwargs["evaluate_target_id"]
|
||||
|
||||
target: Union[App, CustomizedSnippet] | None = None
|
||||
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
target = db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
|
||||
elif target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
target = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not target:
|
||||
raise NotFound(f"{str(target_type)} not found")
|
||||
|
||||
kwargs["target"] = target
|
||||
kwargs["target_type"] = target_type
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def _load_evaluation_run_request_and_dataset(tenant_id: str) -> tuple[EvaluationRunRequest, bytes, str]:
|
||||
"""Validate the run payload and load the uploaded dataset bytes."""
|
||||
body = request.get_json(force=True)
|
||||
if not body:
|
||||
raise BadRequest("Request body is required.")
|
||||
|
||||
try:
|
||||
run_request = EvaluationRunRequest.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
upload_file = db.session.query(UploadFile).filter_by(id=run_request.file_id, tenant_id=tenant_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("Dataset file not found.")
|
||||
|
||||
try:
|
||||
dataset_content = storage.load_once(upload_file.key)
|
||||
except Exception:
|
||||
raise BadRequest("Failed to read dataset file.")
|
||||
|
||||
if not dataset_content:
|
||||
raise BadRequest("Dataset file is empty.")
|
||||
|
||||
return run_request, dataset_content, upload_file.name
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
|
||||
class EvaluationDatasetTemplateDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_dataset_template")
|
||||
@console_ns.response(200, "Template file streamed as XLSX attachment")
|
||||
@console_ns.response(400, "Invalid target type or excluded app mode")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Download evaluation dataset template.
|
||||
|
||||
Generates an XLSX template based on the target's input parameters
|
||||
and streams it directly as a file attachment.
|
||||
"""
|
||||
try:
|
||||
xlsx_content, filename = EvaluationService.generate_dataset_template(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
encoded_filename = quote(filename)
|
||||
response = Response(
|
||||
xlsx_content,
|
||||
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
)
|
||||
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
response.headers["Content-Length"] = str(len(xlsx_content))
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
|
||||
class EvaluationDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation configuration for the target.
|
||||
|
||||
Returns evaluation configuration including model settings,
|
||||
metrics config, and judgement conditions.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.get_evaluation_config(session, current_tenant_id, target_type, str(target.id))
|
||||
|
||||
if config is None:
|
||||
return {
|
||||
"evaluation_model": None,
|
||||
"evaluation_model_provider": None,
|
||||
"default_metrics": None,
|
||||
"customized_metrics": None,
|
||||
"judgment_config": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
|
||||
}
|
||||
|
||||
@console_ns.doc("save_evaluation_detail")
|
||||
@console_ns.response(200, "Evaluation configuration saved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def put(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Save evaluation configuration for the target.
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
body = request.get_json(force=True)
|
||||
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
config = EvaluationService.save_evaluation_config(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
data=config_data,
|
||||
)
|
||||
|
||||
return {
|
||||
"evaluation_model": config.evaluation_model,
|
||||
"evaluation_model_provider": config.evaluation_model_provider,
|
||||
"default_metrics": EvaluationService.serialize_console_default_metrics(config.default_metrics_list),
|
||||
"customized_metrics": config.customized_metrics_dict,
|
||||
"judgment_config": EvaluationService.serialize_console_judgment_config(config.judgment_config_dict),
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/template-columns")
|
||||
class EvaluationTemplateColumnsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_template_columns")
|
||||
@console_ns.response(200, "Evaluation dataset columns resolved", evaluation_dataset_columns_response_model)
|
||||
@console_ns.response(400, "Invalid request body")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return the dataset template columns implied by the current evaluation config."""
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
config_data = EvaluationConfigData.model_validate(body)
|
||||
except Exception as e:
|
||||
raise BadRequest(f"Invalid request body: {e}")
|
||||
|
||||
return {
|
||||
"columns": EvaluationService.get_dataset_column_names(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
data=config_data,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
|
||||
class EvaluationLogsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_logs")
|
||||
@console_ns.response(200, "Evaluation logs retrieved successfully")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation run history for the target.
|
||||
|
||||
Returns a paginated list of evaluation runs.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
runs, total = EvaluationService.get_evaluation_runs(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": [_serialize_evaluation_run(run) for run in runs],
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run1")
|
||||
class EvaluationRunApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""
|
||||
Start an evaluation run.
|
||||
|
||||
Expects JSON body with:
|
||||
- file_id: uploaded dataset file ID
|
||||
- evaluation_model: evaluation model name
|
||||
- evaluation_model_provider: evaluation model provider
|
||||
- default_metrics: list of default metric objects
|
||||
- customized_metrics: customized metrics object (optional)
|
||||
- judgment_config: judgment conditions config (optional)
|
||||
"""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
if target_type == EvaluationTargetType.APPS.value:
|
||||
evaluation_run = EvaluationService.start_stub_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
else:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/run")
|
||||
class EvaluationRunRealApi(Resource):
|
||||
@console_ns.doc("start_evaluation_run_real")
|
||||
@console_ns.response(200, "Evaluation run started")
|
||||
@console_ns.response(400, "Invalid request")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet, Dataset], target_type: str):
|
||||
"""Start the real evaluation execution flow on the temporary dev path."""
|
||||
current_account, current_tenant_id = current_account_with_tenant()
|
||||
run_request, dataset_content, dataset_filename = _load_evaluation_run_request_and_dataset(current_tenant_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
evaluation_run = EvaluationService.start_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
target_type=target_type,
|
||||
target_id=str(target.id),
|
||||
account_id=str(current_account.id),
|
||||
dataset_file_content=dataset_content,
|
||||
dataset_filename=dataset_filename,
|
||||
run_request=run_request,
|
||||
)
|
||||
return _serialize_evaluation_run(evaluation_run), 200
|
||||
except EvaluationFrameworkNotConfiguredError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except EvaluationMaxConcurrentRunsError as e:
|
||||
return {"message": str(e.description)}, 429
|
||||
except EvaluationDatasetInvalidError as e:
|
||||
return {"message": str(e.description)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>")
|
||||
class EvaluationRunDetailApi(Resource):
|
||||
@console_ns.doc("get_evaluation_run_detail")
|
||||
@console_ns.response(200, "Evaluation run detail retrieved")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""
|
||||
Get evaluation run detail including items.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 50, type=int)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.get_evaluation_run_detail(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
items, total_items = EvaluationService.get_evaluation_run_items(
|
||||
session=session,
|
||||
run_id=run_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return {
|
||||
"run": _serialize_evaluation_run(run),
|
||||
"items": {
|
||||
"data": [_serialize_evaluation_run_item(item) for item in items],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
}
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/runs/<uuid:run_id>/cancel")
|
||||
class EvaluationRunCancelApi(Resource):
|
||||
@console_ns.doc("cancel_evaluation_run")
|
||||
@console_ns.response(200, "Evaluation run cancelled")
|
||||
@console_ns.response(404, "Run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
@edit_permission_required
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str, run_id: str):
|
||||
"""Cancel a running evaluation."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
run_id = str(run_id)
|
||||
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
run = EvaluationService.cancel_evaluation_run(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
run_id=run_id,
|
||||
)
|
||||
return _serialize_evaluation_run(run)
|
||||
except EvaluationNotFoundError as e:
|
||||
return {"message": str(e.description)}, 404
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/metrics")
|
||||
class EvaluationMetricsApi(Resource):
|
||||
@console_ns.doc("get_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics retrieved")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get available evaluation metrics for the current framework.
|
||||
"""
|
||||
result = {}
|
||||
for category in EvaluationCategory:
|
||||
if category in EvaluationService.CONSOLE_DISABLED_CATEGORIES:
|
||||
continue
|
||||
result[category.value] = EvaluationService.get_supported_metrics(category)
|
||||
return {"metrics": result}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/default-metrics")
|
||||
class EvaluationDefaultMetricsApi(Resource):
|
||||
@console_ns.doc(
|
||||
"get_evaluation_default_metrics_with_nodes",
|
||||
description=(
|
||||
"List default metrics supported by the current evaluation framework with matching nodes "
|
||||
"from the target's published workflow only (draft is ignored)."
|
||||
),
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Default metrics and node candidates for the published workflow",
|
||||
evaluation_default_metrics_response_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
default_metrics = EvaluationService.get_default_metrics_with_nodes_for_published_target(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
)
|
||||
return {
|
||||
"default_metrics": [
|
||||
m.model_dump() for m in EvaluationService.filter_console_default_metrics(default_metrics)
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/node-info")
|
||||
class EvaluationNodeInfoApi(Resource):
|
||||
@console_ns.doc("get_evaluation_node_info")
|
||||
@console_ns.response(200, "Node info grouped by metric")
|
||||
@console_ns.response(404, "Target not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""Return workflow/snippet node info grouped by requested metrics.
|
||||
|
||||
Request body (JSON):
|
||||
- metrics: list[str] | None – metric names to query; omit or pass
|
||||
an empty list to get all nodes under key ``"all"``.
|
||||
|
||||
Response:
|
||||
``{metric_or_all: [{"node_id": ..., "type": ..., "title": ...}, ...]}``
|
||||
"""
|
||||
body = request.get_json(silent=True) or {}
|
||||
metrics: list[str] | None = body.get("metrics") or None
|
||||
|
||||
result = EvaluationService.get_nodes_for_metrics(
|
||||
target=target,
|
||||
target_type=target_type,
|
||||
metrics=metrics,
|
||||
)
|
||||
if not metrics:
|
||||
result = {
|
||||
"all": [
|
||||
node
|
||||
for node in result.get("all", [])
|
||||
if node.get("type") not in EvaluationService.CONSOLE_DISABLED_CATEGORIES
|
||||
]
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
metric: nodes
|
||||
for metric, nodes in result.items()
|
||||
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/evaluation/available-metrics")
|
||||
class EvaluationAvailableMetricsApi(Resource):
|
||||
@console_ns.doc("get_available_evaluation_metrics")
|
||||
@console_ns.response(200, "Available metrics list")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Return the centrally-defined list of evaluation metrics."""
|
||||
return {
|
||||
"metrics": [
|
||||
metric
|
||||
for metric in EvaluationService.get_available_metrics()
|
||||
if metric not in EvaluationService.CONSOLE_DISABLED_METRICS
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
|
||||
class EvaluationFileDownloadApi(Resource):
|
||||
@console_ns.doc("download_evaluation_file")
|
||||
@console_ns.response(200, "File download URL generated successfully")
|
||||
@console_ns.response(404, "Target or file not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
|
||||
"""
|
||||
Download evaluation test file or result file.
|
||||
|
||||
Looks up the specified file, verifies it belongs to the same tenant,
|
||||
and returns file info and download URL.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == file_id,
|
||||
UploadFile.tenant_id == current_tenant_id,
|
||||
)
|
||||
upload_file = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
|
||||
"download_url": download_url,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
|
||||
class EvaluationVersionApi(Resource):
|
||||
@console_ns.doc("get_evaluation_version_detail")
|
||||
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
|
||||
@console_ns.response(200, "Version details retrieved successfully")
|
||||
@console_ns.response(404, "Target or version not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_evaluation_target
|
||||
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
|
||||
"""
|
||||
Get evaluation target version details.
|
||||
|
||||
Returns the workflow graph for the specified version.
|
||||
"""
|
||||
version = request.args.get("version")
|
||||
|
||||
if not version:
|
||||
return {"message": "version parameter is required"}, 400
|
||||
|
||||
graph = {}
|
||||
if target_type == EvaluationTargetType.SNIPPETS.value and isinstance(target, CustomizedSnippet):
|
||||
graph = target.graph_dict
|
||||
|
||||
return {
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/available-evaluation-workflows")
|
||||
class AvailableEvaluationWorkflowsApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@console_ns.doc("list_available_evaluation_workflows")
|
||||
@console_ns.doc(description="List published evaluation workflows in the current workspace (all apps)")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Available evaluation workflows retrieved",
|
||||
available_evaluation_workflow_pagination_model,
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self):
|
||||
"""List published evaluation-type workflows for the current tenant (cross-app)."""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
page = args.page
|
||||
limit = args.limit
|
||||
user_id = args.user_id
|
||||
named_only = args.named_only
|
||||
keyword = args.keyword
|
||||
|
||||
if user_id and user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.list_published_evaluation_workflows(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
keyword=keyword,
|
||||
)
|
||||
|
||||
app_ids = {w.app_id for w in workflows}
|
||||
if app_ids:
|
||||
apps = session.scalars(select(App).where(App.id.in_(app_ids))).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
else:
|
||||
app_names = {}
|
||||
|
||||
items = []
|
||||
for wf in workflows:
|
||||
items.append(
|
||||
{
|
||||
"id": wf.id,
|
||||
"app_id": wf.app_id,
|
||||
"app_name": app_names.get(wf.app_id, ""),
|
||||
"type": wf.type.value,
|
||||
"kind": wf.kind_or_standard,
|
||||
"version": wf.version,
|
||||
"marked_name": wf.marked_name,
|
||||
"marked_comment": wf.marked_comment,
|
||||
"hash": wf.unique_hash,
|
||||
"created_by": wf.created_by_account,
|
||||
"created_at": wf.created_at,
|
||||
"updated_by": wf.updated_by_account,
|
||||
"updated_at": wf.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
return (
|
||||
marshal(
|
||||
{"items": items, "page": page, "limit": limit, "has_more": has_more},
|
||||
available_evaluation_workflow_pagination_fields,
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/evaluation-workflows/<string:workflow_id>/associated-targets")
|
||||
class EvaluationWorkflowAssociatedTargetsApi(Resource):
|
||||
@console_ns.doc("list_evaluation_workflow_associated_targets")
|
||||
@console_ns.doc(
|
||||
description="List targets (apps / snippets / knowledge bases) that use the given workflow as customized metrics"
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, workflow_id: str):
|
||||
"""Return all evaluation targets that reference this workflow as customized metrics."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
configs = EvaluationService.list_targets_by_customized_workflow(
|
||||
session=session,
|
||||
tenant_id=current_tenant_id,
|
||||
customized_workflow_id=workflow_id,
|
||||
)
|
||||
|
||||
target_ids_by_type: dict[str, list[str]] = {}
|
||||
for cfg in configs:
|
||||
target_ids_by_type.setdefault(cfg.target_type, []).append(cfg.target_id)
|
||||
|
||||
app_names: dict[str, str] = {}
|
||||
if EvaluationTargetType.APPS.value in target_ids_by_type:
|
||||
apps = session.scalars(
|
||||
select(App).where(App.id.in_(target_ids_by_type[EvaluationTargetType.APPS.value]))
|
||||
).all()
|
||||
app_names = {a.id: a.name for a in apps}
|
||||
|
||||
snippet_names: dict[str, str] = {}
|
||||
if "snippets" in target_ids_by_type:
|
||||
snippets = session.scalars(
|
||||
select(CustomizedSnippet).where(CustomizedSnippet.id.in_(target_ids_by_type["snippets"]))
|
||||
).all()
|
||||
snippet_names = {s.id: s.name for s in snippets}
|
||||
|
||||
dataset_names: dict[str, str] = {}
|
||||
if "knowledge_base" in target_ids_by_type:
|
||||
datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.id.in_(target_ids_by_type["knowledge_base"]))
|
||||
).all()
|
||||
dataset_names = {d.id: d.name for d in datasets}
|
||||
|
||||
items = []
|
||||
for cfg in configs:
|
||||
name = ""
|
||||
if cfg.target_type == EvaluationTargetType.APPS.value:
|
||||
name = app_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == EvaluationTargetType.SNIPPETS.value:
|
||||
name = snippet_names.get(cfg.target_id, "")
|
||||
elif cfg.target_type == "knowledge_base":
|
||||
name = dataset_names.get(cfg.target_id, "")
|
||||
|
||||
items.append(
|
||||
{
|
||||
"target_type": cfg.target_type,
|
||||
"target_id": cfg.target_id,
|
||||
"target_name": name,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": items}, 200
|
||||
|
||||
|
||||
# ---- Serialization Helpers ----
|
||||
|
||||
|
||||
def _serialize_evaluation_run(run: EvaluationRun) -> dict[str, object]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"target_type": run.target_type,
|
||||
"target_id": run.target_id,
|
||||
"evaluation_config_id": run.evaluation_config_id,
|
||||
"status": run.status,
|
||||
"dataset_file_id": run.dataset_file_id,
|
||||
"result_file_id": run.result_file_id,
|
||||
"total_items": run.total_items,
|
||||
"completed_items": run.completed_items,
|
||||
"failed_items": run.failed_items,
|
||||
"progress": run.progress,
|
||||
"metrics_summary": run.metrics_summary_dict,
|
||||
"error": run.error,
|
||||
"created_by": run.created_by,
|
||||
"started_at": int(run.started_at.timestamp()) if run.started_at else None,
|
||||
"completed_at": int(run.completed_at.timestamp()) if run.completed_at else None,
|
||||
"created_at": int(run.created_at.timestamp()) if run.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_evaluation_run_item(item: EvaluationRunItem) -> dict[str, object]:
|
||||
return {
|
||||
"id": item.id,
|
||||
"item_index": item.item_index,
|
||||
"inputs": item.inputs_dict,
|
||||
"expected_output": item.expected_output,
|
||||
"actual_output": item.actual_output,
|
||||
"metrics": item.metrics_list,
|
||||
"judgment": item.judgment_dict,
|
||||
"metadata": item.metadata_dict,
|
||||
"error": item.error,
|
||||
"overall_score": item.overall_score,
|
||||
}
|
||||
@ -6,7 +6,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -72,6 +73,7 @@ class ChatMessagePayload(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
@ -130,6 +132,7 @@ class CompletionApi(InstalledAppResource):
|
||||
endpoint="installed_app_stop_completion",
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
@ -205,6 +208,7 @@ class ChatApi(InstalledAppResource):
|
||||
endpoint="installed_app_stop_chat_completion",
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -34,6 +34,7 @@ class ConversationListQuery(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
register_response_schema_models(console_ns, ResultResponse)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -89,6 +90,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
endpoint="installed_app_conversation",
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -142,6 +144,7 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
endpoint="installed_app_conversation_pin",
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
@ -165,6 +168,7 @@ class ConversationPinApi(InstalledAppResource):
|
||||
endpoint="installed_app_conversation_unpin",
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
@ -8,7 +8,8 @@ from pydantic import BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import and_, select
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleMessageResponse, SimpleResultMessageResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
@ -16,6 +17,7 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
@ -105,9 +107,7 @@ class InstalledAppResponse(ResponseModel):
|
||||
@field_validator("last_used_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class InstalledAppListResponse(ResponseModel):
|
||||
@ -123,6 +123,7 @@ register_schema_models(
|
||||
InstalledAppResponse,
|
||||
InstalledAppListResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleMessageResponse, SimpleResultMessageResponse)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps")
|
||||
@ -210,6 +211,7 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -259,6 +261,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
use InstalledAppResource to apply default decorators and get installed_app
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
@ -269,6 +272,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
|
||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from pydantic import BaseModel, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
CompletionRequestError,
|
||||
@ -49,6 +49,7 @@ class MoreLikeThisQuery(BaseModel):
|
||||
|
||||
|
||||
register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, MoreLikeThisQuery)
|
||||
register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsResponse)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -93,6 +94,7 @@ class MessageListApi(InstalledAppResource):
|
||||
)
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
@ -166,6 +168,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
endpoint="installed_app_suggested_question",
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
||||
@ -64,28 +64,15 @@ class RecommendedAppListResponse(ResponseModel):
|
||||
categories: list[str]
|
||||
|
||||
|
||||
class LearnDifyAppListResponse(ResponseModel):
|
||||
recommended_apps: list[RecommendedAppResponse]
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
RecommendedAppsQuery,
|
||||
RecommendedAppInfoResponse,
|
||||
RecommendedAppResponse,
|
||||
RecommendedAppListResponse,
|
||||
LearnDifyAppListResponse,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_language(language: str | None) -> str:
|
||||
if language and language in languages:
|
||||
return language
|
||||
if current_user and current_user.interface_language:
|
||||
return current_user.interface_language
|
||||
return languages[0]
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps")
|
||||
class RecommendedAppListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@ -95,7 +82,13 @@ class RecommendedAppListApi(Resource):
|
||||
def get(self):
|
||||
# language args
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language_prefix = _resolve_language(args.language)
|
||||
language = args.language
|
||||
if language and language in languages:
|
||||
language_prefix = language
|
||||
elif current_user and current_user.interface_language:
|
||||
language_prefix = current_user.interface_language
|
||||
else:
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
@ -103,22 +96,6 @@ class RecommendedAppListApi(Resource):
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/learn-dify")
|
||||
class LearnDifyAppListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RecommendedAppsQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[LearnDifyAppListResponse.__name__])
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = RecommendedAppsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
language_prefix = _resolve_language(args.language)
|
||||
|
||||
return LearnDifyAppListResponse.model_validate(
|
||||
RecommendedAppService.get_learn_dify_apps(language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/explore/apps/<uuid:app_id>")
|
||||
class RecommendedAppApi(Resource):
|
||||
@login_required
|
||||
|
||||
@ -3,7 +3,7 @@ from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@ -14,6 +14,7 @@ from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
register_response_schema_models(console_ns, ResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
@ -42,6 +43,7 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
@ -62,6 +64,7 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
"/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", endpoint="installed_app_saved_message"
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
def delete(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
|
||||
@ -3,7 +3,8 @@ import logging
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -34,6 +35,7 @@ from .. import console_ns
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
@ -78,6 +80,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop")
|
||||
class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
"""
|
||||
Stop workflow task
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
@ -40,12 +41,6 @@ def _mask_api_key(api_key: str) -> str:
|
||||
return api_key[:3] + "******" + api_key[-3:]
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class APIBasedExtensionResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -61,7 +56,7 @@ class APIBasedExtensionResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, APIBasedExtensionPayload, CodeBasedExtensionResponse, APIBasedExtensionResponse)
|
||||
@ -75,6 +70,21 @@ def _serialize_api_based_extension(extension: APIBasedExtension) -> dict[str, An
|
||||
return APIBasedExtensionResponse.model_validate(extension, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
def _serialize_saved_api_based_extension(extension: APIBasedExtension, api_key: str) -> dict[str, Any]:
|
||||
"""Serialize a saved extension with the plaintext key used for response masking only.
|
||||
|
||||
APIBasedExtensionService.save mutates the ORM object to hold the encrypted token before returning it. The response
|
||||
contract, however, should match list/detail responses, where api_key is masked from the decrypted token.
|
||||
"""
|
||||
return APIBasedExtensionResponse(
|
||||
id=extension.id,
|
||||
name=extension.name,
|
||||
api_endpoint=extension.api_endpoint,
|
||||
api_key=api_key,
|
||||
created_at=to_timestamp(extension.created_at),
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/code-based-extension")
|
||||
class CodeBasedExtensionAPI(Resource):
|
||||
@console_ns.doc("get_code_based_extension")
|
||||
@ -130,7 +140,7 @@ class APIBasedExtensionAPI(Resource):
|
||||
api_key=payload.api_key,
|
||||
)
|
||||
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data))
|
||||
return _serialize_saved_api_based_extension(APIBasedExtensionService.save(extension_data), payload.api_key), 201
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
@ -165,14 +175,19 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
api_key_for_response = extension_data_from_db.api_key
|
||||
|
||||
extension_data_from_db.name = payload.name
|
||||
extension_data_from_db.api_endpoint = payload.api_endpoint
|
||||
|
||||
if payload.api_key != HIDDEN_VALUE:
|
||||
extension_data_from_db.api_key = payload.api_key
|
||||
api_key_for_response = payload.api_key
|
||||
|
||||
return _serialize_api_based_extension(APIBasedExtensionService.save(extension_data_from_db))
|
||||
return _serialize_saved_api_based_extension(
|
||||
APIBasedExtensionService.save(extension_data_from_db),
|
||||
api_key_for_response,
|
||||
)
|
||||
|
||||
@console_ns.doc("delete_api_based_extension")
|
||||
@console_ns.doc(description="Delete API-based extension")
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureService
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
class FeatureApi(Resource):
|
||||
@ -15,7 +18,7 @@ class FeatureApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
|
||||
console_ns.models[FeatureModel.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -35,9 +38,7 @@ class SystemFeatureApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
|
||||
),
|
||||
console_ns.models[SystemFeatureModel.__name__],
|
||||
)
|
||||
def get(self):
|
||||
"""Get system-wide feature configuration
|
||||
|
||||
@ -15,7 +15,8 @@ from controllers.common.errors import (
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import AllowedExtensionsResponse, TextContentResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
@ -29,6 +30,7 @@ from services.file_service import FileService
|
||||
from . import console_ns
|
||||
|
||||
register_schema_models(console_ns, UploadConfig, FileResponse)
|
||||
register_response_schema_models(console_ns, AllowedExtensionsResponse, TextContentResponse)
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
@ -103,9 +105,11 @@ class FilePreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
@ -114,5 +118,6 @@ class FileSupportTypeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AllowedExtensionsResponse.__name__])
|
||||
def get(self):
|
||||
return {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}
|
||||
|
||||
@ -5,6 +5,8 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -48,6 +50,9 @@ class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@ -110,6 +115,7 @@ class NotificationDismissApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
@ -24,8 +25,13 @@ class RemoteFileUploadPayload(BaseModel):
|
||||
url: str = Field(..., description="URL to fetch")
|
||||
|
||||
|
||||
register_schema_models(console_ns, RemoteFileUploadPayload)
|
||||
register_response_schema_models(console_ns, FileWithSignedUrl, RemoteFileInfo)
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/<path:url>")
|
||||
class GetRemoteFileInfo(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
@ -41,6 +47,8 @@ class GetRemoteFileInfo(Resource):
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUpload(Resource):
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
|
||||
@login_required
|
||||
def post(self):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
|
||||
@ -1,142 +0,0 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class SnippetListQuery(BaseModel):
|
||||
"""Query parameters for listing snippets."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
is_published: bool | None = Field(default=None, description="Filter by published status")
|
||||
creators: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
|
||||
@field_validator("creators", mode="before")
|
||||
@classmethod
|
||||
def parse_creators(cls, value: object) -> list[str] | None:
|
||||
"""Normalize creators filter from query string or list input."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return [creator.strip() for creator in value.split(",") if creator.strip()] or None
|
||||
if isinstance(value, list):
|
||||
return [str(creator).strip() for creator in value if str(creator).strip()] or None
|
||||
return None
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
"""Icon information model."""
|
||||
|
||||
icon: str | None = None
|
||||
icon_type: Literal["emoji", "image"] | None = None
|
||||
icon_background: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class InputFieldDefinition(BaseModel):
|
||||
"""Input field definition for snippet parameters."""
|
||||
|
||||
default: str | None = None
|
||||
hint: bool | None = None
|
||||
label: str | None = None
|
||||
max_length: int | None = None
|
||||
options: list[str] | None = None
|
||||
placeholder: str | None = None
|
||||
required: bool | None = None
|
||||
type: str | None = None # e.g., "text-input"
|
||||
|
||||
|
||||
class CreateSnippetPayload(BaseModel):
|
||||
"""Payload for creating a new snippet."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
type: Literal["node", "group"] = "node"
|
||||
icon_info: IconInfo | None = None
|
||||
graph: dict[str, Any] | None = None
|
||||
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
|
||||
|
||||
|
||||
class UpdateSnippetPayload(BaseModel):
|
||||
"""Payload for updating a snippet."""
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = Field(default=None, max_length=2000)
|
||||
icon_info: IconInfo | None = None
|
||||
|
||||
|
||||
class SnippetDraftSyncPayload(BaseModel):
|
||||
"""Payload for syncing snippet draft workflow."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description="Ignored. Snippet workflows do not persist conversation variables.",
|
||||
)
|
||||
input_fields: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetWorkflowListQuery(BaseModel):
|
||||
"""Query parameters for listing snippet published workflows."""
|
||||
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
"""Query parameters for workflow runs."""
|
||||
|
||||
last_id: str | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SnippetDraftRunPayload(BaseModel):
|
||||
"""Payload for running snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetDraftNodeRunPayload(BaseModel):
|
||||
"""Payload for running a single node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any]
|
||||
query: str = ""
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class SnippetIterationNodeRunPayload(BaseModel):
|
||||
"""Payload for running an iteration node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetLoopNodeRunPayload(BaseModel):
|
||||
"""Payload for running a loop node in snippet draft workflow."""
|
||||
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PublishWorkflowPayload(BaseModel):
|
||||
"""Payload for publishing snippet workflow."""
|
||||
|
||||
knowledge_base_setting: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SnippetImportPayload(BaseModel):
|
||||
"""Payload for importing snippet from DSL."""
|
||||
|
||||
mode: str = Field(..., description="Import mode: yaml-content or yaml-url")
|
||||
yaml_content: str | None = Field(default=None, description="YAML content (required for yaml-content mode)")
|
||||
yaml_url: str | None = Field(default=None, description="YAML URL (required for yaml-url mode)")
|
||||
name: str | None = Field(default=None, description="Override snippet name")
|
||||
description: str | None = Field(default=None, description="Override snippet description")
|
||||
snippet_id: str | None = Field(default=None, description="Snippet ID to update (optional)")
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
"""Query parameter for including secret variables in export."""
|
||||
|
||||
include_secret: str = Field(default="false", description="Whether to include secret variables")
|
||||
@ -1,617 +0,0 @@
|
||||
# import logging
|
||||
# from collections.abc import Callable
|
||||
# from functools import wraps
|
||||
|
||||
# from flask import request
|
||||
# from flask_restx import Resource, fields, marshal, marshal_with
|
||||
# from sqlalchemy.orm import Session
|
||||
# from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
# from controllers.common.schema import register_schema_models
|
||||
# from controllers.console import console_ns
|
||||
# from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
# from controllers.console.app.workflow import (
|
||||
# RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE,
|
||||
# workflow_model,
|
||||
# workflow_pagination_model,
|
||||
# )
|
||||
# from controllers.console.app.workflow_run import (
|
||||
# workflow_run_detail_model,
|
||||
# workflow_run_node_execution_list_model,
|
||||
# workflow_run_node_execution_model,
|
||||
# workflow_run_pagination_model,
|
||||
# )
|
||||
# from controllers.console.snippets.payloads import (
|
||||
# PublishWorkflowPayload,
|
||||
# SnippetDraftNodeRunPayload,
|
||||
# SnippetDraftRunPayload,
|
||||
# SnippetDraftSyncPayload,
|
||||
# SnippetIterationNodeRunPayload,
|
||||
# SnippetLoopNodeRunPayload,
|
||||
# SnippetWorkflowListQuery,
|
||||
# WorkflowRunQuery,
|
||||
# )
|
||||
# from controllers.console.wraps import (
|
||||
# account_initialization_required,
|
||||
# edit_permission_required,
|
||||
# setup_required,
|
||||
# )
|
||||
# from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
# from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
# from extensions.ext_database import db
|
||||
# from extensions.ext_redis import redis_client
|
||||
# from graphon.graph_engine.manager import GraphEngineManager
|
||||
# from libs import helper
|
||||
# from libs.helper import TimestampField
|
||||
# from libs.login import current_account_with_tenant, login_required
|
||||
# from models.snippet import CustomizedSnippet
|
||||
# from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError
|
||||
# from services.snippet_generate_service import SnippetGenerateService
|
||||
# from services.snippet_service import SnippetService
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# # Register Pydantic models with Swagger
|
||||
# register_schema_models(
|
||||
# console_ns,
|
||||
# SnippetDraftSyncPayload,
|
||||
# SnippetDraftNodeRunPayload,
|
||||
# SnippetDraftRunPayload,
|
||||
# SnippetIterationNodeRunPayload,
|
||||
# SnippetLoopNodeRunPayload,
|
||||
# SnippetWorkflowListQuery,
|
||||
# WorkflowRunQuery,
|
||||
# PublishWorkflowPayload,
|
||||
# )
|
||||
|
||||
|
||||
# snippet_workflow_model = console_ns.clone("SnippetWorkflow", workflow_model, {
|
||||
# "input_fields": fields.Raw(default=[]),
|
||||
# })
|
||||
|
||||
|
||||
# class SnippetNotFoundError(Exception):
|
||||
# """Snippet not found error."""
|
||||
|
||||
# pass
|
||||
|
||||
|
||||
# def get_snippet[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
# """Decorator to fetch and validate snippet access."""
|
||||
|
||||
# @wraps(view_func)
|
||||
# def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# if not kwargs.get("snippet_id"):
|
||||
# raise ValueError("missing snippet_id in path parameters")
|
||||
|
||||
# _, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# snippet_id = str(kwargs.get("snippet_id"))
|
||||
# del kwargs["snippet_id"]
|
||||
|
||||
# snippet = SnippetService.get_snippet_by_id(
|
||||
# snippet_id=snippet_id,
|
||||
# tenant_id=current_tenant_id,
|
||||
# )
|
||||
|
||||
# if not snippet:
|
||||
# raise NotFound("Snippet not found")
|
||||
|
||||
# kwargs["snippet"] = snippet
|
||||
|
||||
# return view_func(*args, **kwargs)
|
||||
|
||||
# return decorated_view
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
|
||||
# class SnippetDraftWorkflowApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_workflow")
|
||||
# @console_ns.response(200, "Draft workflow retrieved successfully", snippet_workflow_model)
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @marshal_with(snippet_workflow_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get draft workflow for snippet."""
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
# if not workflow:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# db.session.expunge(workflow)
|
||||
# workflow.conversation_variables = []
|
||||
# workflow.input_fields = snippet.input_fields_list
|
||||
# return workflow
|
||||
|
||||
# @console_ns.doc("sync_snippet_draft_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
|
||||
# @console_ns.response(200, "Draft workflow synced successfully")
|
||||
# @console_ns.response(400, "Hash mismatch")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """Sync draft workflow for snippet."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
|
||||
# payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# try:
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.sync_draft_workflow(
|
||||
# snippet=snippet,
|
||||
# graph=payload.graph,
|
||||
# unique_hash=payload.hash,
|
||||
# account=current_user,
|
||||
# input_fields=payload.input_fields,
|
||||
# )
|
||||
# except WorkflowHashNotEqualError:
|
||||
# raise DraftWorkflowNotSync()
|
||||
# except ValueError as e:
|
||||
# return {"message": str(e)}, 400
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "hash": workflow.unique_hash,
|
||||
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
|
||||
# class SnippetDraftConfigApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_config")
|
||||
# @console_ns.response(200, "Draft config retrieved successfully")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get snippet draft workflow configuration limits."""
|
||||
# return {
|
||||
# "parallel_depth_limit": 3,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
|
||||
# class SnippetPublishedWorkflowApi(Resource):
|
||||
# @console_ns.doc("get_snippet_published_workflow")
|
||||
# @console_ns.response(200, "Published workflow retrieved successfully", snippet_workflow_model)
|
||||
# @console_ns.response(404, "Snippet not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @marshal_with(snippet_workflow_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get published workflow for snippet."""
|
||||
# if not snippet.is_published:
|
||||
# return None
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
|
||||
# if workflow:
|
||||
# workflow.input_fields = snippet.input_fields_list
|
||||
|
||||
# return workflow
|
||||
|
||||
# @console_ns.doc("publish_snippet_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
|
||||
# @console_ns.response(200, "Workflow published successfully")
|
||||
# @console_ns.response(400, "No draft workflow found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """Publish snippet workflow."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# snippet_service = SnippetService()
|
||||
|
||||
# with Session(db.engine) as session:
|
||||
# snippet = session.merge(snippet)
|
||||
# try:
|
||||
# workflow = snippet_service.publish_workflow(
|
||||
# session=session,
|
||||
# snippet=snippet,
|
||||
# account=current_user,
|
||||
# )
|
||||
# workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
# session.commit()
|
||||
# except ValueError as e:
|
||||
# return {"message": str(e)}, 400
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "created_at": workflow_created_at,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
|
||||
# class SnippetDefaultBlockConfigsApi(Resource):
|
||||
# @console_ns.doc("get_snippet_default_block_configs")
|
||||
# @console_ns.response(200, "Default block configs retrieved successfully")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get default block configurations for snippet workflow."""
|
||||
# snippet_service = SnippetService()
|
||||
# return snippet_service.get_default_block_configs()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows")
|
||||
# class SnippetPublishedAllWorkflowApi(Resource):
|
||||
# @console_ns.expect(console_ns.models[SnippetWorkflowListQuery.__name__])
|
||||
# @console_ns.doc("get_all_snippet_published_workflows")
|
||||
# @console_ns.doc(description="Get all published workflows for a snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID"})
|
||||
# @console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """Get all published workflow versions for snippet."""
|
||||
# args = SnippetWorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# with Session(db.engine) as session:
|
||||
# workflows, has_more = snippet_service.get_all_published_workflows(
|
||||
# session=session,
|
||||
# snippet=snippet,
|
||||
# page=args.page,
|
||||
# limit=args.limit,
|
||||
# )
|
||||
# serialized_workflows = marshal(workflows, workflow_model)
|
||||
|
||||
# return {
|
||||
# "items": serialized_workflows,
|
||||
# "page": args.page,
|
||||
# "limit": args.limit,
|
||||
# "has_more": has_more,
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/<string:workflow_id>/restore")
|
||||
# class SnippetDraftWorkflowRestoreApi(Resource):
|
||||
# @console_ns.doc("restore_snippet_workflow_to_draft")
|
||||
# @console_ns.doc(description="Restore a published snippet workflow version into the draft workflow")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "workflow_id": "Published workflow ID"})
|
||||
# @console_ns.response(200, "Workflow restored successfully")
|
||||
# @console_ns.response(400, "Source workflow must be published")
|
||||
# @console_ns.response(404, "Workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, workflow_id: str):
|
||||
# """Restore a published snippet workflow version into the draft workflow."""
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# snippet_service = SnippetService()
|
||||
|
||||
# try:
|
||||
# workflow = snippet_service.restore_published_workflow_to_draft(
|
||||
# snippet=snippet,
|
||||
# workflow_id=workflow_id,
|
||||
# account=current_user,
|
||||
# )
|
||||
# except IsDraftWorkflowError as exc:
|
||||
# raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc
|
||||
# except WorkflowNotFoundError as exc:
|
||||
# raise NotFound(str(exc)) from exc
|
||||
# except ValueError as exc:
|
||||
# raise BadRequest(str(exc)) from exc
|
||||
|
||||
# return {
|
||||
# "result": "success",
|
||||
# "hash": workflow.unique_hash,
|
||||
# "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
|
||||
# }
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
|
||||
# class SnippetWorkflowRunsApi(Resource):
|
||||
# @console_ns.doc("list_snippet_workflow_runs")
|
||||
# @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_pagination_model)
|
||||
# def get(self, snippet: CustomizedSnippet):
|
||||
# """List workflow runs for snippet."""
|
||||
# query = WorkflowRunQuery.model_validate(
|
||||
# {
|
||||
# "last_id": request.args.get("last_id"),
|
||||
# "limit": request.args.get("limit", type=int, default=20),
|
||||
# }
|
||||
# )
|
||||
# args = {
|
||||
# "last_id": query.last_id,
|
||||
# "limit": query.limit,
|
||||
# }
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
|
||||
|
||||
# return result
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
|
||||
# class SnippetWorkflowRunDetailApi(Resource):
|
||||
# @console_ns.doc("get_snippet_workflow_run_detail")
|
||||
# @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
# @console_ns.response(404, "Workflow run not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_detail_model)
|
||||
# def get(self, snippet: CustomizedSnippet, run_id):
|
||||
# """Get workflow run detail for snippet."""
|
||||
# run_id = str(run_id)
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
|
||||
# if not workflow_run:
|
||||
# raise NotFound("Workflow run not found")
|
||||
|
||||
# return workflow_run
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
# class SnippetWorkflowRunNodeExecutionsApi(Resource):
|
||||
# @console_ns.doc("list_snippet_workflow_run_node_executions")
|
||||
# @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet, run_id):
|
||||
# """List node executions for a workflow run."""
|
||||
# run_id = str(run_id)
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# node_executions = snippet_service.get_snippet_workflow_run_node_executions(
|
||||
# snippet=snippet,
|
||||
# run_id=run_id,
|
||||
# )
|
||||
|
||||
# return {"data": node_executions}
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftNodeRunApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_node")
|
||||
# @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_model)
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a single node in snippet draft workflow.
|
||||
|
||||
# Executes a specific node with provided inputs for single-step debugging.
|
||||
# Returns the node execution result including status, outputs, and timing.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# user_inputs = payload.inputs
|
||||
|
||||
# # Get draft workflow for file parsing
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if not draft_workflow:
|
||||
# raise NotFound("Draft workflow not found")
|
||||
|
||||
# files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
|
||||
|
||||
# workflow_node_execution = SnippetGenerateService.run_draft_node(
|
||||
# snippet=snippet,
|
||||
# node_id=node_id,
|
||||
# user_inputs=user_inputs,
|
||||
# account=current_user,
|
||||
# query=payload.query,
|
||||
# files=files,
|
||||
# )
|
||||
|
||||
# return workflow_node_execution
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
|
||||
# class SnippetDraftNodeLastRunApi(Resource):
|
||||
# @console_ns.doc("get_snippet_draft_node_last_run")
|
||||
# @console_ns.doc(description="Get last run result for a node in snippet draft workflow")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
|
||||
# @console_ns.response(404, "Snippet, draft workflow, or node last run not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @marshal_with(workflow_run_node_execution_model)
|
||||
# def get(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Get the last run result for a specific node in snippet draft workflow.
|
||||
|
||||
# Returns the most recent execution record for the given node,
|
||||
# including status, inputs, outputs, and timing information.
|
||||
# """
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if not draft_workflow:
|
||||
# raise NotFound("Draft workflow not found")
|
||||
|
||||
# node_exec = snippet_service.get_snippet_node_last_run(
|
||||
# snippet=snippet,
|
||||
# workflow=draft_workflow,
|
||||
# node_id=node_id,
|
||||
# )
|
||||
# if node_exec is None:
|
||||
# raise NotFound("Node last run not found")
|
||||
|
||||
# return node_exec
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftRunIterationNodeApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_iteration_node")
|
||||
# @console_ns.doc(description="Run draft workflow iteration node for snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Iteration node run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a draft workflow iteration node for snippet.
|
||||
|
||||
# Iteration nodes execute their internal sub-graph multiple times over an input list.
|
||||
# Returns an SSE event stream with iteration progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate_single_iteration(
|
||||
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
# class SnippetDraftRunLoopNodeApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_loop_node")
|
||||
# @console_ns.doc(description="Run draft workflow loop node for snippet")
|
||||
# @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
|
||||
# @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
|
||||
# @console_ns.response(200, "Loop node run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, node_id: str):
|
||||
# """
|
||||
# Run a draft workflow loop node for snippet.
|
||||
|
||||
# Loop nodes execute their internal sub-graph repeatedly until a condition is met.
|
||||
# Returns an SSE event stream with loop progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
# args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate_single_loop(
|
||||
# snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
|
||||
# class SnippetDraftWorkflowRunApi(Resource):
|
||||
# @console_ns.doc("run_snippet_draft_workflow")
|
||||
# @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
|
||||
# @console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
|
||||
# @console_ns.response(404, "Snippet or draft workflow not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet):
|
||||
# """
|
||||
# Run draft workflow for snippet.
|
||||
|
||||
# Executes the snippet's draft workflow with the provided inputs
|
||||
# and returns an SSE event stream with execution progress and results.
|
||||
# """
|
||||
# current_user, _ = current_account_with_tenant()
|
||||
|
||||
# payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
|
||||
# args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# try:
|
||||
# response = SnippetGenerateService.generate(
|
||||
# snippet=snippet,
|
||||
# user=current_user,
|
||||
# args=args,
|
||||
# invoke_from=InvokeFrom.DEBUGGER,
|
||||
# streaming=True,
|
||||
# )
|
||||
|
||||
# return helper.compact_generate_response(response)
|
||||
# except ValueError as e:
|
||||
# raise e
|
||||
# except Exception:
|
||||
# logger.exception("internal server error.")
|
||||
# raise InternalServerError()
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
# class SnippetWorkflowTaskStopApi(Resource):
|
||||
# @console_ns.doc("stop_snippet_workflow_task")
|
||||
# @console_ns.response(200, "Task stopped successfully")
|
||||
# @console_ns.response(404, "Snippet not found")
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# def post(self, snippet: CustomizedSnippet, task_id: str):
|
||||
# """
|
||||
# Stop a running snippet workflow task.
|
||||
|
||||
# Uses both the legacy stop flag mechanism and the graph engine
|
||||
# command channel for backward compatibility.
|
||||
# """
|
||||
# # Stop using both mechanisms for backward compatibility
|
||||
# # Legacy stop flag mechanism (without user check)
|
||||
# AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# # New graph engine command channel mechanism
|
||||
# GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
|
||||
# return {"result": "success"}
|
||||
@ -1,316 +0,0 @@
|
||||
# """
|
||||
# Snippet draft workflow variable APIs.
|
||||
|
||||
# Mirrors console app routes under /apps/.../workflows/draft/variables for snippet scope,
|
||||
# using CustomizedSnippet.id as WorkflowDraftVariable.app_id (same invariant as snippet execution).
|
||||
|
||||
# Snippet workflows do not expose system variables (`node_id == sys`) or conversation variables
|
||||
# (`node_id == conversation`): paginated list queries exclude those rows; single-variable GET/PATCH/DELETE/reset
|
||||
# reject them; `GET .../system-variables` and `GET .../conversation-variables` return empty lists for API parity.
|
||||
# Other routes mirror `workflow_draft_variable` app APIs under `/snippets/...`.
|
||||
# """
|
||||
|
||||
# from collections.abc import Callable
|
||||
# from functools import wraps
|
||||
# from typing import Any
|
||||
|
||||
# from flask import Response, request
|
||||
# from flask_restx import Resource, marshal, marshal_with
|
||||
# from sqlalchemy.orm import Session
|
||||
|
||||
# from controllers.console import console_ns
|
||||
# from controllers.console.app.error import DraftWorkflowNotExist
|
||||
# from controllers.console.app.workflow_draft_variable import (
|
||||
# WorkflowDraftVariableListQuery,
|
||||
# WorkflowDraftVariableUpdatePayload,
|
||||
# _ensure_variable_access,
|
||||
# _file_access_controller,
|
||||
# validate_node_id,
|
||||
# workflow_draft_variable_list_model,
|
||||
# workflow_draft_variable_list_without_value_model,
|
||||
# workflow_draft_variable_model,
|
||||
# )
|
||||
# from controllers.console.snippets.snippet_workflow import get_snippet
|
||||
# from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
# from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
# from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
# from extensions.ext_database import db
|
||||
# from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
# from factories.variable_factory import build_segment_with_type
|
||||
# from graphon.variables.types import SegmentType
|
||||
# from libs.login import current_user, login_required
|
||||
# from models.snippet import CustomizedSnippet
|
||||
# from models.workflow import WorkflowDraftVariable
|
||||
# from services.snippet_service import SnippetService
|
||||
# from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
# _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS: frozenset[str] = frozenset(
|
||||
# {SYSTEM_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID}
|
||||
# )
|
||||
|
||||
|
||||
# def _ensure_snippet_draft_variable_row_allowed(
|
||||
# *,
|
||||
# variable: WorkflowDraftVariable,
|
||||
# variable_id: str,
|
||||
# ) -> None:
|
||||
# """Snippet scope only supports canvas-node draft variables; treat sys/conversation rows as not found."""
|
||||
# if variable.node_id in _SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS:
|
||||
# raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
|
||||
# def _snippet_draft_var_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
# """Setup, auth, snippet resolution, and tenant edit permission (same stack as snippet workflow APIs)."""
|
||||
|
||||
# @setup_required
|
||||
# @login_required
|
||||
# @account_initialization_required
|
||||
# @get_snippet
|
||||
# @edit_permission_required
|
||||
# @wraps(f)
|
||||
# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# return f(*args, **kwargs)
|
||||
|
||||
# return wrapper
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables")
|
||||
# class SnippetWorkflowVariableCollectionApi(Resource):
|
||||
# @console_ns.expect(console_ns.models[WorkflowDraftVariableListQuery.__name__])
|
||||
# @console_ns.doc("get_snippet_workflow_variables")
|
||||
# @console_ns.doc(description="List draft workflow variables without values (paginated, snippet scope)")
|
||||
# @console_ns.response(
|
||||
# 200,
|
||||
# "Workflow variables retrieved successfully",
|
||||
# workflow_draft_variable_list_without_value_model,
|
||||
# )
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# args = WorkflowDraftVariableListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# snippet_service = SnippetService()
|
||||
# if snippet_service.get_draft_workflow(snippet=snippet) is None:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
# workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
# app_id=snippet.id,
|
||||
# page=args.page,
|
||||
# limit=args.limit,
|
||||
# user_id=current_user.id,
|
||||
# exclude_node_ids=_SNIPPET_EXCLUDED_DRAFT_VARIABLE_NODE_IDS,
|
||||
# )
|
||||
|
||||
# return workflow_vars
|
||||
|
||||
# @console_ns.doc("delete_snippet_workflow_variables")
|
||||
# @console_ns.doc(description="Delete all draft workflow variables for the current user (snippet scope)")
|
||||
# @console_ns.response(204, "Workflow variables deleted successfully")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet) -> Response:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# draft_var_srv.delete_user_workflow_variables(snippet.id, user_id=current_user.id)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
# class SnippetNodeVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_node_variables")
|
||||
# @console_ns.doc(description="Get variables for a specific node (snippet draft workflow)")
|
||||
# @console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet, node_id: str) -> WorkflowDraftVariableList:
|
||||
# validate_node_id(node_id)
|
||||
# with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=session)
|
||||
# node_vars = draft_var_srv.list_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
|
||||
# return node_vars
|
||||
|
||||
# @console_ns.doc("delete_snippet_node_variables")
|
||||
# @console_ns.doc(description="Delete all variables for a specific node (snippet draft workflow)")
|
||||
# @console_ns.response(204, "Node variables deleted successfully")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet, node_id: str) -> Response:
|
||||
# validate_node_id(node_id)
|
||||
# srv = WorkflowDraftVariableService(db.session())
|
||||
# srv.delete_node_variables(snippet.id, node_id, user_id=current_user.id)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
# class SnippetVariableApi(Resource):
|
||||
# @console_ns.doc("get_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Get a specific draft workflow variable (snippet scope)")
|
||||
# @console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_model)
|
||||
# def get(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
# return variable
|
||||
|
||||
# @console_ns.doc("update_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Update a draft workflow variable (snippet scope)")
|
||||
# @console_ns.expect(console_ns.models[WorkflowDraftVariableUpdatePayload.__name__])
|
||||
# @console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_model)
|
||||
# def patch(self, snippet: CustomizedSnippet, variable_id: str) -> WorkflowDraftVariable:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
# new_name = args_model.name
|
||||
# raw_value = args_model.value
|
||||
# if new_name is None and raw_value is None:
|
||||
# return variable
|
||||
|
||||
# new_value = None
|
||||
# if raw_value is not None:
|
||||
# if variable.value_type == SegmentType.FILE:
|
||||
# if not isinstance(raw_value, dict):
|
||||
# raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
# raw_value = build_from_mapping(
|
||||
# mapping=raw_value,
|
||||
# tenant_id=snippet.tenant_id,
|
||||
# access_controller=_file_access_controller,
|
||||
# )
|
||||
# elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
# if not isinstance(raw_value, list):
|
||||
# raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
# if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
# raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
# raw_value = build_from_mappings(
|
||||
# mappings=raw_value,
|
||||
# tenant_id=snippet.tenant_id,
|
||||
# access_controller=_file_access_controller,
|
||||
# )
|
||||
# new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
# draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
# db.session.commit()
|
||||
# return variable
|
||||
|
||||
# @console_ns.doc("delete_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Delete a draft workflow variable (snippet scope)")
|
||||
# @console_ns.response(204, "Variable deleted successfully")
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def delete(self, snippet: CustomizedSnippet, variable_id: str) -> Response:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
# draft_var_srv.delete_variable(variable)
|
||||
# db.session.commit()
|
||||
# return Response("", 204)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
# class SnippetVariableResetApi(Resource):
|
||||
# @console_ns.doc("reset_snippet_workflow_variable")
|
||||
# @console_ns.doc(description="Reset a draft workflow variable to its default value (snippet scope)")
|
||||
# @console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
# @console_ns.response(204, "Variable reset (no content)")
|
||||
# @console_ns.response(404, "Variable not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def put(self, snippet: CustomizedSnippet, variable_id: str) -> Response | Any:
|
||||
# draft_var_srv = WorkflowDraftVariableService(session=db.session())
|
||||
# snippet_service = SnippetService()
|
||||
# draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if draft_workflow is None:
|
||||
# raise NotFoundError(
|
||||
# f"Draft workflow not found, snippet_id={snippet.id}",
|
||||
# )
|
||||
# variable = _ensure_variable_access(
|
||||
# variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
# app_id=snippet.id,
|
||||
# variable_id=variable_id,
|
||||
# )
|
||||
# _ensure_snippet_draft_variable_row_allowed(variable=variable, variable_id=variable_id)
|
||||
|
||||
# resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
# db.session.commit()
|
||||
# if resetted is None:
|
||||
# return Response("", 204)
|
||||
# return marshal(resetted, workflow_draft_variable_model)
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/conversation-variables")
|
||||
# class SnippetConversationVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_conversation_variables")
|
||||
# @console_ns.doc(
|
||||
# description="Conversation variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
# )
|
||||
# @console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/system-variables")
|
||||
# class SnippetSystemVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_system_variables")
|
||||
# @console_ns.doc(
|
||||
# description="System variables are not used in snippet workflows; returns an empty list for API parity"
|
||||
# )
|
||||
# @console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# @marshal_with(workflow_draft_variable_list_model)
|
||||
# def get(self, snippet: CustomizedSnippet) -> WorkflowDraftVariableList:
|
||||
# return WorkflowDraftVariableList(variables=[])
|
||||
|
||||
|
||||
# @console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/environment-variables")
|
||||
# class SnippetEnvironmentVariableCollectionApi(Resource):
|
||||
# @console_ns.doc("get_snippet_environment_variables")
|
||||
# @console_ns.doc(description="Get environment variables from snippet draft workflow graph")
|
||||
# @console_ns.response(200, "Environment variables retrieved successfully")
|
||||
# @console_ns.response(404, "Draft workflow not found")
|
||||
# @_snippet_draft_var_prerequisite
|
||||
# def get(self, snippet: CustomizedSnippet) -> dict[str, list[dict[str, Any]]]:
|
||||
# snippet_service = SnippetService()
|
||||
# workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
# if workflow is None:
|
||||
# raise DraftWorkflowNotExist()
|
||||
|
||||
# env_vars_list: list[dict[str, Any]] = []
|
||||
# for v in workflow.environment_variables:
|
||||
# env_vars_list.append(
|
||||
# {
|
||||
# "id": v.id,
|
||||
# "type": "env",
|
||||
# "name": v.name,
|
||||
# "description": v.description,
|
||||
# "selector": v.selector,
|
||||
# "value_type": v.value_type.exposed_type().value,
|
||||
# "value": v.value,
|
||||
# "edited": False,
|
||||
# "visible": True,
|
||||
# "editable": True,
|
||||
# }
|
||||
# )
|
||||
|
||||
# return {"items": env_vars_list}
|
||||
@ -5,7 +5,8 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
@ -25,6 +26,10 @@ class TagBasePayload(BaseModel):
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagUpdateRequestPayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
@ -68,11 +73,13 @@ class TagResponse(ResponseModel):
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
TagBasePayload,
|
||||
TagUpdateRequestPayload,
|
||||
TagBindingPayload,
|
||||
TagBindingRemovePayload,
|
||||
TagListQueryParam,
|
||||
TagResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/tags")
|
||||
@ -97,6 +104,7 @@ class TagListApi(Resource):
|
||||
return serialized_tags, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TagResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -118,7 +126,8 @@ class TagListApi(Resource):
|
||||
|
||||
@console_ns.route("/tags/<uuid:tag_id>")
|
||||
class TagUpdateDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
|
||||
@console_ns.expect(console_ns.models[TagUpdateRequestPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[TagResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -129,8 +138,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@ -144,6 +153,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@console_ns.response(204, "Tag deleted successfully")
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
@ -198,6 +208,7 @@ class TagBindingCollectionApi(Resource):
|
||||
|
||||
@console_ns.doc("create_tag_binding")
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -212,6 +223,7 @@ class TagBindingRemoveApi(Resource):
|
||||
@console_ns.doc("remove_tag_bindings")
|
||||
@console_ns.doc(description="Remove one or more tag bindings from a target.")
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -12,7 +12,13 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import (
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -42,7 +48,7 @@ from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
@ -185,12 +191,6 @@ def _serialize_account(account) -> dict[str, Any]:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
if isinstance(value, datetime):
|
||||
return int(value.timestamp())
|
||||
return value
|
||||
|
||||
|
||||
class AccountIntegrateResponse(ResponseModel):
|
||||
provider: str
|
||||
created_at: int | None = None
|
||||
@ -200,7 +200,7 @@ class AccountIntegrateResponse(ResponseModel):
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class AccountIntegrateListResponse(ResponseModel):
|
||||
@ -220,7 +220,7 @@ class EducationStatusResponse(ResponseModel):
|
||||
@field_validator("expire_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_expire_at(cls, value: datetime | int | None) -> int | None:
|
||||
return _to_timestamp(value)
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class EducationAutocompleteResponse(ResponseModel):
|
||||
@ -237,11 +237,19 @@ register_schema_models(
|
||||
EducationStatusResponse,
|
||||
EducationAutocompleteResponse,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/account/init")
|
||||
class AccountInitApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountInitPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
@ -318,6 +326,7 @@ class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -479,6 +488,7 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@ -491,6 +501,7 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@console_ns.route("/account/delete")
|
||||
class AccountDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -511,6 +522,7 @@ class AccountDeleteApi(Resource):
|
||||
@console_ns.route("/account/delete/feedback")
|
||||
class AccountDeleteUpdateFeedbackApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
@ -590,6 +602,7 @@ class EducationAutoCompleteApi(Resource):
|
||||
@console_ns.route("/account/change-email")
|
||||
class ChangeEmailSendEmailApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -655,6 +668,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
@console_ns.route("/account/change-email/validity")
|
||||
class ChangeEmailCheckApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[VerificationTokenResponse.__name__])
|
||||
@enable_change_email
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -771,6 +785,7 @@ class ChangeEmailResetApi(Resource):
|
||||
@console_ns.route("/account/change-email/check-email-unique")
|
||||
class CheckEmailUnique(Resource):
|
||||
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
|
||||
@ -6,7 +6,8 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.common.fields import SimpleResultDataResponse, VerificationTokenResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
@ -30,14 +31,13 @@ from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
emails: list[str] = Field(default_factory=list)
|
||||
role: str
|
||||
role: TenantAccountRole
|
||||
language: str | None = None
|
||||
|
||||
|
||||
@ -69,20 +69,9 @@ register_schema_models(
|
||||
OwnerTransferCheckPayload,
|
||||
OwnerTransferPayload,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultDataResponse, VerificationTokenResponse)
|
||||
|
||||
|
||||
def _serialize_member_roles(current_role: str | None, member_roles: list[enterprise_rbac_service.MemberRoleSummary]) -> list[dict[str, str]]:
|
||||
if member_roles:
|
||||
return [{"id": role.id, "name": role.name} for role in member_roles]
|
||||
if current_role:
|
||||
return [{"id": current_role, "name": current_role}]
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_enum_value(value: object) -> str:
|
||||
normalized = getattr(value, "value", value)
|
||||
return str(normalized) if normalized is not None else ""
|
||||
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
@ -102,36 +91,7 @@ class MemberListApi(Resource):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
member_ids = [member.id for member in members]
|
||||
member_roles = enterprise_rbac_service.RBACService.MemberRoles.batch_get(
|
||||
str(current_user.current_tenant.id),
|
||||
current_user.id,
|
||||
member_ids,
|
||||
)
|
||||
roles_map = {item.account_id: item.roles for item in member_roles}
|
||||
else:
|
||||
roles_map = {}
|
||||
|
||||
serialized_members = []
|
||||
for member in members:
|
||||
current_role = _normalize_enum_value(member.current_role)
|
||||
serialized_members.append(
|
||||
{
|
||||
"id": member.id,
|
||||
"name": member.name,
|
||||
"email": member.email,
|
||||
"avatar": member.avatar,
|
||||
"last_login_at": member.last_login_at,
|
||||
"last_active_at": member.last_active_at,
|
||||
"created_at": member.created_at,
|
||||
"role": current_role,
|
||||
"roles": _serialize_member_roles(current_role, roles_map.get(member.id, [])),
|
||||
"status": _normalize_enum_value(member.status),
|
||||
}
|
||||
)
|
||||
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(serialized_members)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = AccountWithRoleList(accounts=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
|
||||
@ -152,9 +112,8 @@ class MemberInviteEmailApi(Resource):
|
||||
invitee_emails = args.emails
|
||||
invitee_role = args.role
|
||||
interface_language = args.language
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
if not TenantAccountRole.is_valid_role(invitee_role) or not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
@ -305,6 +264,7 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
"""Send owner transfer email."""
|
||||
|
||||
@console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -342,6 +302,7 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
@console_ns.route("/workspaces/current/members/owner-transfer-check")
|
||||
class OwnerTransferCheckApi(Resource):
|
||||
@console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[VerificationTokenResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -5,7 +5,8 @@ from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
@ -85,6 +86,7 @@ register_schema_models(
|
||||
ParserCredentialValidate,
|
||||
ParserPreferredProviderType,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers")
|
||||
@ -177,6 +179,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__])
|
||||
@console_ns.response(204, "Credential deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -197,6 +200,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -271,6 +275,7 @@ class ModelProviderIconApi(Resource):
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
|
||||
class PreferredProviderTypeUpdateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
|
||||
@ -5,7 +5,8 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
@ -126,6 +127,7 @@ register_schema_models(
|
||||
Inner,
|
||||
ParserSwitch,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
register_enum_models(console_ns, ModelType)
|
||||
|
||||
@ -149,6 +151,7 @@ class DefaultModelApi(Resource):
|
||||
return jsonable_encoder({"data": default_model_entity})
|
||||
|
||||
@console_ns.expect(console_ns.models[ParserPostDefault.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -241,6 +244,7 @@ class ModelProviderModelApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@console_ns.response(204, "Model deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -373,6 +377,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__])
|
||||
@console_ns.response(204, "Credential deleted successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -396,6 +401,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserSwitch.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@ -420,6 +426,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
)
|
||||
class ModelProviderModelEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -441,6 +448,7 @@ class ModelProviderModelEnableApi(Resource):
|
||||
)
|
||||
class ModelProviderModelDisableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
import io
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, TypedDict
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model, register_enum_models, register_schema_models
|
||||
from controllers.common.fields import SuccessResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
@ -23,14 +25,6 @@ from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class AutoUpgradeSettingsResponse(TypedDict):
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting
|
||||
upgrade_time_of_day: int
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode
|
||||
exclude_plugins: list[str]
|
||||
include_plugins: list[str]
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
page_size: int = Field(default=256, ge=1, le=256, description="Page size (1-256)")
|
||||
@ -94,8 +88,8 @@ class ParserUninstall(BaseModel):
|
||||
|
||||
|
||||
class ParserPermissionChange(BaseModel):
|
||||
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
|
||||
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
|
||||
install_permission: TenantPluginPermission.InstallPermission
|
||||
debug_permission: TenantPluginPermission.DebugPermission
|
||||
|
||||
|
||||
class ParserDynamicOptions(BaseModel):
|
||||
@ -131,22 +125,13 @@ class PluginAutoUpgradeSettingsPayload(BaseModel):
|
||||
include_plugins: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ParserAutoUpgradeChange(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
class ParserPreferencesChange(BaseModel):
|
||||
permission: PluginPermissionSettingsPayload
|
||||
auto_upgrade: PluginAutoUpgradeSettingsPayload
|
||||
|
||||
|
||||
class ParserAutoUpgradeFetch(BaseModel):
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
|
||||
|
||||
class ParserExcludePlugin(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
plugin_id: str
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory
|
||||
|
||||
|
||||
class ParserReadme(BaseModel):
|
||||
@ -154,6 +139,12 @@ class ParserReadme(BaseModel):
|
||||
language: str = Field(default="en-US")
|
||||
|
||||
|
||||
class PluginDebuggingKeyResponse(ResponseModel):
|
||||
key: str
|
||||
host: str
|
||||
port: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
ParserList,
|
||||
@ -173,45 +164,21 @@ register_schema_models(
|
||||
ParserPermissionChange,
|
||||
ParserDynamicOptions,
|
||||
ParserDynamicOptionsWithCredentials,
|
||||
ParserAutoUpgradeChange,
|
||||
ParserAutoUpgradeFetch,
|
||||
ParserPreferencesChange,
|
||||
ParserExcludePlugin,
|
||||
ParserReadme,
|
||||
)
|
||||
register_response_schema_models(console_ns, PluginDebuggingKeyResponse, SuccessResponse)
|
||||
|
||||
register_enum_models(
|
||||
console_ns,
|
||||
TenantPluginPermission.DebugPermission,
|
||||
TenantPluginAutoUpgradeStrategy.PluginCategory,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
TenantPluginPermission.InstallPermission,
|
||||
)
|
||||
|
||||
|
||||
def _default_auto_upgrade_settings(
|
||||
tenant_id: str,
|
||||
category: TenantPluginAutoUpgradeStrategy.PluginCategory,
|
||||
) -> AutoUpgradeSettingsResponse:
|
||||
return {
|
||||
"strategy_setting": PluginAutoUpgradeService.default_strategy_setting_for_category(category),
|
||||
"upgrade_time_of_day": PluginAutoUpgradeService.default_upgrade_time_of_day(tenant_id),
|
||||
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
"exclude_plugins": [],
|
||||
"include_plugins": [],
|
||||
}
|
||||
|
||||
|
||||
def _auto_upgrade_settings_to_dict(strategy: TenantPluginAutoUpgradeStrategy) -> AutoUpgradeSettingsResponse:
|
||||
return {
|
||||
"strategy_setting": strategy.strategy_setting,
|
||||
"upgrade_time_of_day": strategy.upgrade_time_of_day,
|
||||
"upgrade_mode": strategy.upgrade_mode,
|
||||
"exclude_plugins": strategy.exclude_plugins,
|
||||
"include_plugins": strategy.include_plugins,
|
||||
}
|
||||
|
||||
|
||||
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
"""
|
||||
Read the uploaded file and validate its actual size before delegating to the plugin service.
|
||||
@ -228,6 +195,7 @@ def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/debugging-key")
|
||||
class PluginDebuggingKeyApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[PluginDebuggingKeyResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -533,6 +501,7 @@ class PluginFetchInstallTaskApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||
class PluginDeleteInstallTaskApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -548,6 +517,7 @@ class PluginDeleteInstallTaskApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/delete_all")
|
||||
class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -563,6 +533,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||
class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -628,6 +599,7 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
@console_ns.route("/workspaces/current/plugin/uninstall")
|
||||
class PluginUninstallApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserUninstall.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SuccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -646,6 +618,7 @@ class PluginUninstallApi(Resource):
|
||||
@console_ns.route("/workspaces/current/plugin/permission/change")
|
||||
class PluginChangePermissionApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserPermissionChange.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SuccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -659,13 +632,11 @@ class PluginChangePermissionApi(Resource):
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
set_permission_result = PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
)
|
||||
if not set_permission_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
|
||||
|
||||
return jsonable_encoder({"success": True})
|
||||
return {
|
||||
"success": PluginPermissionService.change_permission(
|
||||
tenant_id, args.install_permission, args.debug_permission
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/permission/fetch")
|
||||
@ -754,9 +725,9 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource):
|
||||
return jsonable_encoder({"options": options})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/change")
|
||||
class PluginChangeAutoUpgradeApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserAutoUpgradeChange.__name__])
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/change")
|
||||
class PluginChangePreferencesApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -765,17 +736,38 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
|
||||
args = ParserPreferencesChange.model_validate(console_ns.payload)
|
||||
|
||||
permission = args.permission
|
||||
|
||||
install_permission = permission.install_permission
|
||||
debug_permission = permission.debug_permission
|
||||
|
||||
auto_upgrade = args.auto_upgrade
|
||||
|
||||
strategy_setting = auto_upgrade.strategy_setting
|
||||
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day
|
||||
upgrade_mode = auto_upgrade.upgrade_mode
|
||||
exclude_plugins = auto_upgrade.exclude_plugins
|
||||
include_plugins = auto_upgrade.include_plugins
|
||||
|
||||
# set permission
|
||||
set_permission_result = PluginPermissionService.change_permission(
|
||||
tenant_id,
|
||||
install_permission,
|
||||
debug_permission,
|
||||
)
|
||||
if not set_permission_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
|
||||
|
||||
# set auto upgrade strategy
|
||||
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
|
||||
tenant_id,
|
||||
auto_upgrade.strategy_setting,
|
||||
auto_upgrade.upgrade_time_of_day,
|
||||
auto_upgrade.upgrade_mode,
|
||||
auto_upgrade.exclude_plugins,
|
||||
auto_upgrade.include_plugins,
|
||||
category=args.category,
|
||||
strategy_setting,
|
||||
upgrade_time_of_day,
|
||||
upgrade_mode,
|
||||
exclude_plugins,
|
||||
include_plugins,
|
||||
)
|
||||
if not set_auto_upgrade_strategy_result:
|
||||
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
|
||||
@ -783,62 +775,6 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
return jsonable_encoder({"success": True})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/fetch")
|
||||
class PluginFetchAutoUpgradeApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(ParserAutoUpgradeFetch))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserAutoUpgradeFetch.model_validate(request.args.to_dict(flat=True))
|
||||
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id, args.category)
|
||||
auto_upgrade_dict = (
|
||||
_auto_upgrade_settings_to_dict(auto_upgrade)
|
||||
if auto_upgrade
|
||||
else _default_auto_upgrade_settings(tenant_id, args.category)
|
||||
)
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"category": args.category,
|
||||
"auto_upgrade": auto_upgrade_dict,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/auto-upgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder(
|
||||
{"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id, args.category)}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/readme")
|
||||
class PluginReadmeApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserReadme.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
|
||||
return jsonable_encoder(
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/fetch")
|
||||
class PluginFetchPreferencesApi(Resource):
|
||||
@setup_required
|
||||
@ -876,3 +812,32 @@ class PluginFetchPreferencesApi(Resource):
|
||||
}
|
||||
|
||||
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
|
||||
class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# exclude one single plugin
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserExcludePlugin.model_validate(console_ns.payload)
|
||||
|
||||
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)})
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/plugin/readme")
|
||||
class PluginReadmeApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserReadme.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserReadme.model_validate(request.args.to_dict(flat=True))
|
||||
return jsonable_encoder(
|
||||
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)}
|
||||
)
|
||||
|
||||
@ -1,614 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.enterprise import rbac_service as svc
|
||||
|
||||
|
||||
_LEGACY_ROLE_PERMISSION_KEYS: dict[str, list[str]] = {
|
||||
# This is a compatibility projection from the pre-RBAC workspace roles into
|
||||
# the 2.0 permission matrix documented in "权限整理2.0". It intentionally
|
||||
# models the product-facing role surface for the new RBAC UI instead of the
|
||||
# legacy backend's exact hard-authorization checks.
|
||||
"owner": [
|
||||
*svc._LEGACY_WORKSPACE_OWNER_KEYS,
|
||||
*svc._LEGACY_APP_OWNER_KEYS,
|
||||
*svc._LEGACY_DATASET_OWNER_KEYS,
|
||||
],
|
||||
"admin": [
|
||||
*svc._LEGACY_WORKSPACE_ADMIN_KEYS,
|
||||
*svc._LEGACY_APP_ADMIN_KEYS,
|
||||
*svc._LEGACY_DATASET_ADMIN_KEYS,
|
||||
],
|
||||
"editor": [
|
||||
*svc._LEGACY_WORKSPACE_EDITOR_KEYS,
|
||||
*svc._LEGACY_APP_EDITOR_KEYS,
|
||||
*svc._LEGACY_DATASET_EDITOR_KEYS,
|
||||
],
|
||||
"normal": [
|
||||
*svc._LEGACY_WORKSPACE_NORMAL_KEYS,
|
||||
*svc._LEGACY_APP_NORMAL_KEYS,
|
||||
],
|
||||
"dataset_operator": [
|
||||
*svc._LEGACY_WORKSPACE_DATASET_OPERATOR_KEYS,
|
||||
*svc._LEGACY_DATASET_DATASET_OPERATOR_KEYS,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _current_ids() -> tuple[str, str]:
|
||||
"""Return ``(tenant_id, account_id)`` for the authenticated user, or
|
||||
raise a 404 when no tenant is associated with the session.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
if not tenant_id:
|
||||
raise NotFound("Current workspace not found")
|
||||
return tenant_id, user.id
|
||||
|
||||
|
||||
def _payload(model: type[BaseModel]) -> Any:
|
||||
"""Validate the JSON body against ``model`` or raise ``ValidationError``.
|
||||
|
||||
``ValidationError`` bubbles up as HTTP 400 thanks to
|
||||
``controllers/common/helpers.py`` error handling.
|
||||
"""
|
||||
try:
|
||||
return model.model_validate(console_ns.payload or {})
|
||||
except ValidationError as exc:
|
||||
# Re-raise as-is so the upstream error handler renders a 400.
|
||||
raise exc
|
||||
|
||||
|
||||
def _dump(model: BaseModel) -> dict[str, Any]:
|
||||
return model.model_dump(mode="json")
|
||||
|
||||
|
||||
class _PaginationQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
page_number: int | None = Field(default=None, ge=1, validation_alias=AliasChoices("page", "page_number"))
|
||||
results_per_page: int | None = Field(
|
||||
default=None, ge=1, le=100, validation_alias=AliasChoices("limit", "results_per_page")
|
||||
)
|
||||
reverse: bool | None = None
|
||||
|
||||
def to_inner_options(self) -> svc.ListOption:
|
||||
return svc.ListOption.model_validate(self.model_dump())
|
||||
|
||||
|
||||
class _RolesListQuery(_PaginationQuery):
|
||||
include_owner: int = Field(default=0, ge=0, le=1)
|
||||
|
||||
|
||||
def _pagination_options() -> svc.ListOption:
|
||||
return _PaginationQuery.model_validate(request.args.to_dict(flat=True)).to_inner_options()
|
||||
|
||||
|
||||
def _filter_out_owner(paginated: svc.Paginated[svc.RBACRole]) -> svc.Paginated[svc.RBACRole]:
|
||||
filtered = [r for r in paginated.data if r.name not in {"所有者", "owner"}]
|
||||
return svc.Paginated[svc.RBACRole](
|
||||
data=filtered,
|
||||
pagination=paginated.pagination,
|
||||
)
|
||||
|
||||
|
||||
def _legacy_workspace_roles(options: svc.ListOption | None = None) -> svc.Paginated[svc.RBACRole]:
|
||||
"""Return the built-in legacy workspace roles in the RBAC list shape.
|
||||
|
||||
This keeps the new `/rbac/roles` endpoint compatible with the original
|
||||
Dify role model when enterprise RBAC is disabled.
|
||||
"""
|
||||
|
||||
legacy_roles = [
|
||||
svc.RBACRole(
|
||||
id=role_name,
|
||||
tenant_id="",
|
||||
type=svc.RBACRoleType.WORKSPACE.value,
|
||||
category="global_system_default",
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
)
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
|
||||
]
|
||||
|
||||
page_number = options.page_number if options and options.page_number is not None else 1
|
||||
results_per_page = options.results_per_page if options and options.results_per_page is not None else len(legacy_roles)
|
||||
reverse = options.reverse if options and options.reverse is not None else False
|
||||
|
||||
ordered_roles = list(reversed(legacy_roles)) if reverse else legacy_roles
|
||||
start = max(page_number - 1, 0) * results_per_page
|
||||
end = start + results_per_page
|
||||
paged_roles = ordered_roles[start:end]
|
||||
total_count = len(legacy_roles)
|
||||
total_pages = (total_count + results_per_page - 1) // results_per_page if results_per_page > 0 else 0
|
||||
|
||||
return svc.Paginated[svc.RBACRole](
|
||||
data=paged_roles,
|
||||
pagination=svc.Pagination(
|
||||
total_count=total_count,
|
||||
per_page=results_per_page,
|
||||
current_page=page_number,
|
||||
total_pages=total_pages,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission catalogs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
|
||||
class RBACWorkspaceCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.workspace(tenant_id, account_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/app")
|
||||
class RBACAppCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.app(tenant_id, account_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog/dataset")
|
||||
class RBACDatasetCatalogApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Catalog.dataset(tenant_id, account_id))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Roles.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _RoleUpsertRequest(BaseModel):
|
||||
"""Accepts the payload sent by the Create/Edit Role dialog."""
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
def to_mutation(self) -> svc.RoleMutation:
|
||||
return svc.RoleMutation(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
permission_keys=list(self.permission_keys),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles")
|
||||
class RBACRolesApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
query = _RolesListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
options = query.to_inner_options()
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
result = _legacy_workspace_roles(options)
|
||||
else:
|
||||
result = svc.RBACService.Roles.list(tenant_id, account_id, options=options)
|
||||
if query.include_owner == 0:
|
||||
result = _filter_out_owner(result)
|
||||
|
||||
data = []
|
||||
for role in result.data:
|
||||
if role.name in {"所有者", "owner"}:
|
||||
role.role_tag = "owner"
|
||||
else:
|
||||
role.role_tag = ""
|
||||
data.append(role)
|
||||
result.data = data
|
||||
return _dump(result)
|
||||
|
||||
@login_required
|
||||
def post(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_RoleUpsertRequest)
|
||||
role = svc.RBACService.Roles.create(tenant_id, account_id, request.to_mutation())
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>")
|
||||
class RBACRoleItemApi(Resource):
|
||||
@login_required
|
||||
def get(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.Roles.get(tenant_id, account_id, str(role_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_RoleUpsertRequest)
|
||||
role = svc.RBACService.Roles.update(tenant_id, account_id, str(role_id), request.to_mutation())
|
||||
return _dump(role)
|
||||
|
||||
@login_required
|
||||
def delete(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
svc.RBACService.Roles.delete(tenant_id, account_id, str(role_id))
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/copy")
|
||||
class RBACRoleCopyApi(Resource):
|
||||
@login_required
|
||||
def post(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
role = svc.RBACService.Roles.copy(tenant_id, account_id, str(role_id))
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access policies (tenant-level permission sets).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AccessPolicyCreateRequest(BaseModel):
|
||||
name: str
|
||||
resource_type: svc.RBACResourceType
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
|
||||
class _AccessPolicyUpdateRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
permission_keys: list[str] = []
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies")
|
||||
class RBACAccessPoliciesApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
# `resource_type` is exposed as a query argument so the UI can show
|
||||
# only app-scoped or only dataset-scoped permission sets.
|
||||
resource_type = request.args.get("resource_type") or None
|
||||
return _dump(
|
||||
svc.RBACService.AccessPolicies.list(
|
||||
tenant_id,
|
||||
account_id,
|
||||
resource_type=resource_type,
|
||||
options=_pagination_options(),
|
||||
)
|
||||
)
|
||||
|
||||
@login_required
|
||||
def post(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_AccessPolicyCreateRequest)
|
||||
policy = svc.RBACService.AccessPolicies.create(
|
||||
tenant_id,
|
||||
account_id,
|
||||
svc.AccessPolicyCreate(
|
||||
name=request.name,
|
||||
resource_type=request.resource_type,
|
||||
description=request.description,
|
||||
permission_keys=list(request.permission_keys),
|
||||
),
|
||||
)
|
||||
return _dump(policy), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>")
|
||||
class RBACAccessPolicyItemApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.AccessPolicies.get(tenant_id, account_id, str(policy_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_AccessPolicyUpdateRequest)
|
||||
policy = svc.RBACService.AccessPolicies.update(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.AccessPolicyUpdate(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
permission_keys=list(request.permission_keys),
|
||||
),
|
||||
)
|
||||
return _dump(policy)
|
||||
|
||||
@login_required
|
||||
def delete(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
svc.RBACService.AccessPolicies.delete(tenant_id, account_id, str(policy_id))
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/access-policies/<uuid:policy_id>/copy")
|
||||
class RBACAccessPolicyCopyApi(Resource):
|
||||
@login_required
|
||||
def post(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
policy = svc.RBACService.AccessPolicies.copy(tenant_id, account_id, str(policy_id))
|
||||
return _dump(policy), 201
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-app access (App Access Config).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplaceBindingsRequest(BaseModel):
|
||||
role_ids: list[str] = []
|
||||
account_ids: list[str] = []
|
||||
|
||||
@field_validator("role_ids", "account_ids", mode="before")
|
||||
@classmethod
|
||||
def _coerce_bindings(cls, value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/my-permissions")
|
||||
class RBACMyPermissionsApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.MyPermissions.get(
|
||||
tenant_id,
|
||||
account_id,
|
||||
app_id=request.args.get("app_id") or None,
|
||||
dataset_id=request.args.get("dataset_id") or None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policy")
|
||||
class RBACAppMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.AppAccess.matrix(tenant_id, account_id, str(app_id)))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACAppRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.list_role_bindings(tenant_id, account_id, str(app_id), str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACAppMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.list_member_bindings(tenant_id, account_id, str(app_id), str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/apps/<uuid:app_id>/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACAppBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, app_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.AppAccess.replace_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(app_id),
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-dataset access (Knowledge Base Access Config).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policy")
|
||||
class RBACDatasetMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.DatasetAccess.matrix(tenant_id, account_id, str(dataset_id)))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACDatasetRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.list_role_bindings(
|
||||
tenant_id, account_id, str(dataset_id), str(policy_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACDatasetBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.replace_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(dataset_id),
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/rbac/datasets/<uuid:dataset_id>/access-policies/<uuid:policy_id>/member-bindings"
|
||||
)
|
||||
class RBACDatasetMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, dataset_id, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.DatasetAccess.list_member_bindings(
|
||||
tenant_id, account_id, str(dataset_id), str(policy_id)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace-level access (Settings > Access Rules).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
|
||||
class RBACWorkspaceAppMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
options = _pagination_options()
|
||||
return _dump(svc.RBACService.WorkspaceAccess.app_matrix(tenant_id, account_id, options=options))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACWorkspaceAppRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_app_role_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACWorkspaceAppBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.replace_app_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACWorkspaceAppMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_app_member_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policy")
|
||||
class RBACWorkspaceDatasetMatrixApi(Resource):
|
||||
@login_required
|
||||
def get(self):
|
||||
tenant_id, account_id = _current_ids()
|
||||
options = _pagination_options()
|
||||
return _dump(svc.RBACService.WorkspaceAccess.dataset_matrix(tenant_id, account_id, options=options))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/role-bindings")
|
||||
class RBACWorkspaceDatasetRoleBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_dataset_role_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/bindings")
|
||||
class RBACWorkspaceDatasetBindingsApi(Resource):
|
||||
@login_required
|
||||
def put(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceBindingsRequest)
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.replace_dataset_bindings(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(policy_id),
|
||||
svc.ReplaceBindings(role_ids=list(request.role_ids), account_ids=list(request.account_ids)),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/datasets/access-policies/<uuid:policy_id>/member-bindings")
|
||||
class RBACWorkspaceDatasetMemberBindingsApi(Resource):
|
||||
@login_required
|
||||
def get(self, policy_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.WorkspaceAccess.list_dataset_member_bindings(tenant_id, account_id, str(policy_id))
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Member ↔ role bindings (Settings > Members > Assign roles).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplaceMemberRolesRequest(BaseModel):
|
||||
role_ids: list[str] = []
|
||||
|
||||
@field_validator("role_ids", mode="before")
|
||||
@classmethod
|
||||
def _coerce_role_ids(cls, value: Any) -> list[str]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/members/<uuid:member_id>/rbac-roles")
|
||||
class RBACMemberRolesApi(Resource):
|
||||
@login_required
|
||||
def get(self, member_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(svc.RBACService.MemberRoles.get(tenant_id, account_id, str(member_id)))
|
||||
|
||||
@login_required
|
||||
def put(self, member_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
request = _payload(_ReplaceMemberRolesRequest)
|
||||
return _dump(
|
||||
svc.RBACService.MemberRoles.replace(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(member_id),
|
||||
role_ids=list(request.role_ids),
|
||||
)
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user