mirror of
https://github.com/langgenius/dify.git
synced 2026-05-09 20:08:06 +08:00
Compare commits
193 Commits
feat/cli
...
build/samp
| Author | SHA1 | Date | |
|---|---|---|---|
| d70acb217a | |||
| 927a17804b | |||
| 29f34848cd | |||
| 1b0d4637b3 | |||
| 936a09c704 | |||
| 5cc62fd1c9 | |||
| 7bc19d8251 | |||
| e845475408 | |||
| 9a8aa6a0c3 | |||
| 76a7f5f4b9 | |||
| 2ff50514c8 | |||
| 7901ac9a97 | |||
| ecd830083a | |||
| 203b3a9499 | |||
| 9331024d91 | |||
| c6a5de3c18 | |||
| cd3327013a | |||
| cd66559ebf | |||
| 8b77ec7f31 | |||
| bb3de5dd32 | |||
| 1e2d309122 | |||
| a24ec60e51 | |||
| 8fd616d27f | |||
| e5bdc40dce | |||
| 376c43e5ac | |||
| 3ebb449d25 | |||
| 5297ac76ec | |||
| bbed1d4a7c | |||
| c804dbed8c | |||
| 00bf3f83f2 | |||
| 7e6745e105 | |||
| d648ce6888 | |||
| f3c3534e33 | |||
| 8967ff34b3 | |||
| 57539792c1 | |||
| 03e227f8f1 | |||
| 506e1a8bc7 | |||
| f8873ec07b | |||
| b2dacf0718 | |||
| 70eb98d6c5 | |||
| b83f296634 | |||
| 5c68f12bb8 | |||
| 4df7c00859 | |||
| 995c43f3dd | |||
| c0431ec843 | |||
| a0af10abc8 | |||
| 8e2b8168be | |||
| 1f29565673 | |||
| 90fe54ca9e | |||
| b43ebf539d | |||
| 853b859032 | |||
| 8f3e42e9c2 | |||
| 1359c03216 | |||
| 4b7dc17546 | |||
| 81090effe2 | |||
| d92c336394 | |||
| cd9daef564 | |||
| 2876839d7e | |||
| 7ba408eebe | |||
| 3708e3eef1 | |||
| ff5c2c57a1 | |||
| 955c25589d | |||
| 54bde0bdf6 | |||
| 87add9a4f3 | |||
| 574d5865f4 | |||
| 458fab1c48 | |||
| 88196c186e | |||
| dcf21a6a84 | |||
| 91f92c7083 | |||
| 0ca339103f | |||
| 5cf741895f | |||
| 11c52e90f6 | |||
| f01e099729 | |||
| 195ff4711d | |||
| fe2f7a8920 | |||
| 3b1458c08f | |||
| 9f47317032 | |||
| e751ec323e | |||
| f1d72eb5d2 | |||
| 44242d03b4 | |||
| ed7ea68f7d | |||
| afbc30c9ed | |||
| 0e55dcb297 | |||
| 25973c7d77 | |||
| 73ecdd5494 | |||
| 6fafeec415 | |||
| d23cefe005 | |||
| 16d408d908 | |||
| 0536549f73 | |||
| d0956039e7 | |||
| 38eb04dc98 | |||
| d2e1da269c | |||
| 1d3498f659 | |||
| b8dea56198 | |||
| e2becd6746 | |||
| 28a26f2d59 | |||
| 8c7393ef46 | |||
| 5a7a955210 | |||
| 3e4849d765 | |||
| 0c280ef708 | |||
| 282561a861 | |||
| cbb4cc5d76 | |||
| 2d6babeeb4 | |||
| 1065a4840a | |||
| b6aa5a7d69 | |||
| 949f930698 | |||
| 65a08ed7ab | |||
| cc4d6db7c8 | |||
| 6b5d6dacb2 | |||
| 89bf75eba9 | |||
| 3a28868a6c | |||
| 4036515abe | |||
| 6c089cab66 | |||
| 818a71d637 | |||
| 3db107edc9 | |||
| 2677d90860 | |||
| 859756c4f6 | |||
| 295fb6e74a | |||
| 2326fb7a83 | |||
| 2d6eaf69f9 | |||
| 3e826c0000 | |||
| b1b977e284 | |||
| 23648141c9 | |||
| d6dee43c09 | |||
| 7efc887e32 | |||
| 8b346e69d9 | |||
| ef7ff3356d | |||
| 7b5c0b5045 | |||
| f00512dd5d | |||
| e6ef774fd5 | |||
| ce50c6cf1c | |||
| 7002512106 | |||
| c3aebb8403 | |||
| 0baefa6163 | |||
| 7bcedcbaab | |||
| 791fc5819d | |||
| 2d09c4788d | |||
| 9bd5c2f8ec | |||
| 5e336c47fd | |||
| be4c828214 | |||
| ec450eb7f9 | |||
| 48e13f65dc | |||
| 38fc2a6574 | |||
| ed8d3f3e8d | |||
| 0c8dec3315 | |||
| 38e831c1b3 | |||
| 1c5d62d98a | |||
| 6b4736bf78 | |||
| c9503fd818 | |||
| 91a1df96cb | |||
| 5b2c5da945 | |||
| b59ecea346 | |||
| 61c0948136 | |||
| f746c7bdf2 | |||
| 2a3deee385 | |||
| 4b6803ba06 | |||
| 4c908c8f39 | |||
| afec528f51 | |||
| 491061b8f4 | |||
| 8b1533438f | |||
| ba924fc97b | |||
| 712e522220 | |||
| 33eebe8cfc | |||
| 2e1b11bdb2 | |||
| d65a6b4810 | |||
| 44a91e344c | |||
| 0fec9af6a6 | |||
| 5e5113e08e | |||
| 73f9a9e7d6 | |||
| 48d23cd744 | |||
| 0b60bf6ef0 | |||
| 3b24d8d2d1 | |||
| 051ba99cd2 | |||
| dc83e8aa09 | |||
| 77f8f2babb | |||
| 77d6c108e7 | |||
| c2a5962023 | |||
| d583b1b835 | |||
| da00de6688 | |||
| a633387e9b | |||
| df389eba1c | |||
| 5cae61eb5a | |||
| ba8e0681d5 | |||
| de123a8695 | |||
| 21e5962f98 | |||
| db60e649b8 | |||
| e561788809 | |||
| 3cd6ef4464 | |||
| 39dc636b02 | |||
| 4f03b7193e | |||
| 0d921cd21d | |||
| 1a7e46368e | |||
| 8c8ad02a6f |
@ -367,7 +367,7 @@ For each extraction:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Extract code │
|
||||
│ 2. Run: pnpm lint:fix │
|
||||
│ 3. Run: pnpm type-check:tsgo │
|
||||
│ 3. Run: pnpm type-check │
|
||||
│ 4. Run: pnpm test │
|
||||
│ 5. Test functionality manually │
|
||||
│ 6. PASS? → Next extraction │
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
|
||||
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
- Create or keep feature hooks only for real orchestration or shared domain behavior.
|
||||
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- legacy `web/service/use-*.ts` files when migrating wrappers away
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
short_description: "Dify TanStack Query, oRPC, and default option patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Thin hook decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
3. Create or keep feature hooks only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
|
||||
4. Treat `web/service/use-{domain}.ts` as legacy.
|
||||
- Do not create new thin service wrappers for oRPC contracts.
|
||||
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
@ -74,11 +79,37 @@ const invoiceQuery = useQuery({
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
```typescript
|
||||
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
|
||||
```
|
||||
|
||||
## Thin Hook Decision Rule
|
||||
|
||||
Remove thin hooks when they only rename a single oRPC query or mutation helper.
|
||||
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
|
||||
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
|
||||
|
||||
Use:
|
||||
|
||||
```typescript
|
||||
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
|
||||
```
|
||||
|
||||
Keep:
|
||||
|
||||
```typescript
|
||||
const applyTagBindingsMutation = useApplyTagBindingsMutation()
|
||||
```
|
||||
|
||||
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
|
||||
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- oRPC default options
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) {
|
||||
}
|
||||
```
|
||||
|
||||
## oRPC Default Options
|
||||
|
||||
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
|
||||
|
||||
Place defaults at the query utility creation point in `web/service/client.ts`:
|
||||
|
||||
```typescript
|
||||
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
|
||||
path: ['console'],
|
||||
experimental_defaults: {
|
||||
tags: {
|
||||
create: {
|
||||
mutationOptions: {
|
||||
onSuccess: (tag, _variables, _result, context) => {
|
||||
context.client.setQueryData(
|
||||
consoleQuery.tags.list.queryKey({
|
||||
input: {
|
||||
query: {
|
||||
type: tag.type,
|
||||
},
|
||||
},
|
||||
}),
|
||||
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
|
||||
)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
|
||||
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
|
||||
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
|
||||
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
|
||||
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
|
||||
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
@ -49,7 +91,7 @@ Use:
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
// Feature orchestration owns cache invalidation only when defaults are not enough.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
|
||||
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
|
||||
@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:
|
||||
|
||||
- ✅ **Import real project components** directly (including base components and siblings)
|
||||
- ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
|
||||
- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
|
||||
- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
|
||||
- ❌ **DO NOT mock** sibling/child components in the same directory
|
||||
|
||||
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
|
||||
@ -325,12 +325,12 @@ For more detailed information, refer to:
|
||||
### Reference Examples in Codebase
|
||||
|
||||
- `web/utils/classnames.spec.ts` - Utility function tests
|
||||
- `web/app/components/base/button/index.spec.tsx` - Component tests
|
||||
- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
|
||||
- `web/__mocks__/provider-context.ts` - Mock factory example
|
||||
|
||||
### Project Configuration
|
||||
|
||||
- `web/vitest.config.ts` - Vitest configuration
|
||||
- `web/vite.config.ts` - Vite/Vitest configuration
|
||||
- `web/vitest.setup.ts` - Test environment setup
|
||||
- `web/scripts/analyze-component.js` - Component analysis tool
|
||||
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
|
||||
|
||||
@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
|
||||
|
||||
### Integration vs Mocking
|
||||
|
||||
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
|
||||
- [ ] Import real project components instead of mocking
|
||||
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
|
||||
- [ ] Prefer integration testing when using single spec file
|
||||
@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
|
||||
|
||||
### Mocks
|
||||
|
||||
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
|
||||
- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
|
||||
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
|
||||
- [ ] Shared mock state reset in `beforeEach`
|
||||
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
|
||||
@ -127,7 +127,7 @@ For the current file being tested:
|
||||
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test:coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
- [ ] Run `pnpm type-check`
|
||||
|
||||
## Common Issues to Watch
|
||||
|
||||
|
||||
@ -2,29 +2,27 @@
|
||||
|
||||
## ⚠️ Important: What NOT to Mock
|
||||
|
||||
### DO NOT Mock Base Components
|
||||
### DO NOT Mock Base Components or dify-ui Primitives
|
||||
|
||||
**Never mock components from `@/app/components/base/`** such as:
|
||||
**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
|
||||
|
||||
- `Loading`, `Spinner`
|
||||
- `Button`, `Input`, `Select`
|
||||
- `Tooltip`, `Modal`, `Dropdown`
|
||||
- `Icon`, `Badge`, `Tag`
|
||||
- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
|
||||
- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
|
||||
|
||||
**Why?**
|
||||
|
||||
- Base components will have their own dedicated tests
|
||||
- These components have their own dedicated tests
|
||||
- Mocking them creates false positives (tests pass but real integration fails)
|
||||
- Using real components tests actual integration behavior
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG: Don't mock base components
|
||||
// ❌ WRONG: Don't mock base components or dify-ui primitives
|
||||
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
|
||||
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
|
||||
vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
|
||||
|
||||
// ✅ CORRECT: Import and use real base components
|
||||
// ✅ CORRECT: Import and use the real components
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { Button } from '@langgenius/dify-ui/button'
|
||||
// They will render normally in tests
|
||||
```
|
||||
|
||||
@ -319,7 +317,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
|
||||
### ✅ DO
|
||||
|
||||
1. **Use real base components** - Import from `@/app/components/base/` directly
|
||||
1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
|
||||
1. **Use real project components** - Prefer importing over mocking
|
||||
1. **Use real Zustand stores** - Set test state via `store.setState()`
|
||||
1. **Reset mocks in `beforeEach`**, not `afterEach`
|
||||
@ -330,7 +328,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
|
||||
### ❌ DON'T
|
||||
|
||||
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
|
||||
1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
|
||||
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
|
||||
1. Don't mock components you can import directly
|
||||
1. Don't create overly simplified mocks that miss conditional logic
|
||||
@ -342,7 +340,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
|
||||
```
|
||||
Need to use a component in test?
|
||||
│
|
||||
├─ Is it from @/app/components/base/*?
|
||||
├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
|
||||
│ └─ YES → Import real component, DO NOT mock
|
||||
│
|
||||
├─ Is it a project component?
|
||||
|
||||
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@ -6,6 +6,9 @@
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# ESLint suppression file is maintained by autofix.ci pruning.
|
||||
/eslint-suppressions.json
|
||||
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
|
||||
2
.github/actions/setup-web/action.yml
vendored
2
.github/actions/setup-web/action.yml
vendored
@ -4,7 +4,7 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
|
||||
1
.github/labeler.yml
vendored
1
.github/labeler.yml
vendored
@ -6,5 +6,4 @@ web:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
19
.github/workflows/anti-slop.yml
vendored
19
.github/workflows/anti-slop.yml
vendored
@ -1,19 +0,0 @@
|
||||
name: Anti-Slop PR Check
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, edited, synchronize]
|
||||
|
||||
permissions:
|
||||
pull-requests: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
anti-slop:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
failure-add-pr-labels: "needs-revision"
|
||||
14
.github/workflows/api-tests.yml
vendored
14
.github/workflows/api-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
api-unit:
|
||||
name: API Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
COVERAGE_FILE: coverage-unit
|
||||
defaults:
|
||||
@ -35,7 +35,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
api-integration:
|
||||
name: API Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
COVERAGE_FILE: coverage-integration
|
||||
STORAGE_TYPE: opendal
|
||||
@ -84,7 +84,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -105,7 +105,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
|
||||
api-coverage:
|
||||
name: API Coverage
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
needs:
|
||||
- api-unit
|
||||
- api-integration
|
||||
@ -156,7 +156,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
15
.github/workflows/autofix.yml
vendored
15
.github/workflows/autofix.yml
vendored
@ -13,7 +13,7 @@ permissions:
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
@ -25,7 +25,7 @@ jobs:
|
||||
- name: Check Docker Compose inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
@ -35,7 +35,7 @@ jobs:
|
||||
- name: Check web inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: web-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
@ -43,12 +43,11 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@ -58,7 +57,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@ -114,7 +113,7 @@ jobs:
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Setup web environment
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: ESLint autofix
|
||||
@ -123,4 +122,4 @@ jobs:
|
||||
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
|
||||
|
||||
46
.github/workflows/build-push.yml
vendored
46
.github/workflows/build-push.yml
vendored
@ -26,6 +26,9 @@ jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
@ -35,28 +38,28 @@ jobs:
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@ -70,8 +73,8 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
@ -81,16 +84,15 @@ jobs:
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=${{ matrix.service_name }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
|
||||
|
||||
- name: Export digest
|
||||
env:
|
||||
@ -108,9 +110,33 @@ jobs:
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
fork-build-validate:
|
||||
if: github.repository != 'langgenius/dify'
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "validate-api-amd64"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "validate-web-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Validate Docker image
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: linux/amd64
|
||||
|
||||
create-manifest:
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
34
.github/workflows/db-migration-test.yml
vendored
34
.github/workflows/db-migration-test.yml
vendored
@ -9,7 +9,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
db-migration-test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -19,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -40,7 +40,7 @@ jobs:
|
||||
cp middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -59,7 +59,7 @@ jobs:
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -69,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@ -110,6 +110,28 @@ jobs:
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
|
||||
# to return (container processes started); it does not wait on healthcheck
|
||||
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
|
||||
# wait the migration runs while InnoDB is still initialising and gets
|
||||
# killed with "Lost connection during query". Poll a real SELECT until it
|
||||
# succeeds.
|
||||
- name: Wait for MySQL to accept queries
|
||||
run: |
|
||||
set +e
|
||||
for i in $(seq 1 60); do
|
||||
if docker run --rm --network host mysql:8.0 \
|
||||
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
|
||||
-e 'SELECT 1' >/dev/null 2>&1; then
|
||||
echo "MySQL ready after ${i}s"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "MySQL not ready after 60s; dumping container logs:"
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
|
||||
exit 1
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
|
||||
2
.github/workflows/deploy-enterprise.yml
vendored
2
.github/workflows/deploy-enterprise.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
|
||||
43
.github/workflows/docker-build.yml
vendored
43
.github/workflows/docker-build.yml
vendored
@ -14,28 +14,59 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
|
||||
build-docker-fork:
|
||||
if: github.event.pull_request.head.repo.full_name != github.repository
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
@ -48,6 +79,4 @@ jobs:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
platforms: linux/amd64
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -7,7 +7,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
with:
|
||||
|
||||
26
.github/workflows/main-ci.yml
vendored
26
.github/workflows/main-ci.yml
vendored
@ -23,7 +23,7 @@ concurrency:
|
||||
jobs:
|
||||
pre_job:
|
||||
name: Skip Duplicate Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
|
||||
steps:
|
||||
@ -39,7 +39,7 @@ jobs:
|
||||
name: Check Changed Files
|
||||
needs: pre_job
|
||||
if: needs.pre_job.outputs.should_skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
@ -69,7 +69,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@ -83,7 +82,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
@ -141,7 +139,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped API tests
|
||||
run: echo "No API-related changes detected; skipping API tests."
|
||||
@ -154,7 +152,7 @@ jobs:
|
||||
- check-changes
|
||||
- api-tests-run
|
||||
- api-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize API Tests status
|
||||
env:
|
||||
@ -201,7 +199,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped web tests
|
||||
run: echo "No web-related changes detected; skipping web tests."
|
||||
@ -214,7 +212,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-tests-run
|
||||
- web-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize Web Tests status
|
||||
env:
|
||||
@ -260,7 +258,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped web full-stack e2e
|
||||
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
|
||||
@ -273,7 +271,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-e2e-run
|
||||
- web-e2e-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize Web Full-Stack E2E status
|
||||
env:
|
||||
@ -325,7 +323,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped VDB tests
|
||||
run: echo "No VDB-related changes detected; skipping VDB tests."
|
||||
@ -338,7 +336,7 @@ jobs:
|
||||
- check-changes
|
||||
- vdb-tests-run
|
||||
- vdb-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize VDB Tests status
|
||||
env:
|
||||
@ -384,7 +382,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped DB migration tests
|
||||
run: echo "No migration-related changes detected; skipping DB migration tests."
|
||||
@ -397,7 +395,7 @@ jobs:
|
||||
- check-changes
|
||||
- db-migration-test-run
|
||||
- db-migration-test-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize DB Migration Test status
|
||||
env:
|
||||
|
||||
2
.github/workflows/pyrefly-diff-comment.yml
vendored
2
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with pyrefly diff
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
4
.github/workflows/pyrefly-diff.yml
vendored
4
.github/workflows/pyrefly-diff.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-diff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with type coverage
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
@ -24,7 +24,7 @@ jobs:
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-type-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
name: Validate PR title
|
||||
permissions:
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
21
.github/workflows/style.yml
vendored
21
.github/workflows/style.yml
vendored
@ -15,7 +15,7 @@ permissions:
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@ -57,7 +57,7 @@ jobs:
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
@ -73,7 +73,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
@ -83,7 +83,6 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
@ -95,7 +94,7 @@ jobs:
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: .eslintcache
|
||||
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
@ -110,6 +109,8 @@ jobs:
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
@ -124,14 +125,14 @@ jobs:
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: .eslintcache
|
||||
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -142,7 +143,7 @@ jobs:
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
|
||||
5
.github/workflows/tool-test-sdks.yaml
vendored
5
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,7 +9,6 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -18,7 +17,7 @@ concurrency:
|
||||
jobs:
|
||||
build:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
defaults:
|
||||
run:
|
||||
@ -30,7 +29,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version: 22
|
||||
cache: ''
|
||||
|
||||
4
.github/workflows/translate-i18n-claude.yml
vendored
4
.github/workflows/translate-i18n-claude.yml
vendored
@ -35,7 +35,7 @@ concurrency:
|
||||
jobs:
|
||||
translate:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
|
||||
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
|
||||
6
.github/workflows/vdb-tests-full.yml
vendored
6
.github/workflows/vdb-tests-full.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
@ -36,7 +36,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -65,7 +65,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
6
.github/workflows/vdb-tests.yml
vendored
6
.github/workflows/vdb-tests.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Smoke Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
@ -33,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
4
.github/workflows/web-e2e.yml
vendored
4
.github/workflows/web-e2e.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Full-Stack E2E
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@ -28,7 +28,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
strategy:
|
||||
@ -54,7 +54,7 @@ jobs:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
@ -92,7 +92,7 @@ jobs:
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -219,6 +219,9 @@ node_modules
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
# generated API OpenAPI specs
|
||||
packages/contracts/openapi/
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
@ -237,6 +240,10 @@ scripts/stress-test/reports/
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# vitest browser mode attachments (failure screenshots, traces, etc.)
|
||||
.vitest-attachments/
|
||||
**/__screenshots__/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
@ -30,7 +30,7 @@ The codebase is split into:
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
|
||||
22
README.md
22
README.md
@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
./dify-compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
#### Seeking help
|
||||
@ -137,20 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||
|
||||
```bash
|
||||
# In your .env file
|
||||
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||
```
|
||||
|
||||
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
@ -160,7 +148,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
|
||||
@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
@ -709,22 +714,6 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Suggested Questions After Answer Configuration
|
||||
# These environment variables allow customization of the suggested questions feature
|
||||
#
|
||||
# Custom prompt for generating suggested questions (optional)
|
||||
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||
# SUGGESTED_QUESTIONS_PROMPT=
|
||||
|
||||
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||
# Adjust this value for longer questions or more questions
|
||||
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||
|
||||
# Temperature for suggested questions generation (default: 0.0)
|
||||
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
|
||||
@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
```
|
||||
uv run dev/generate_swagger_specs.py --output-dir openapi
|
||||
```
|
||||
|
||||
use https://jsontotable.org/openapi-to-typescript to convert to typescript
|
||||
|
||||
@ -159,7 +159,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -204,7 +203,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -33,6 +33,7 @@ from .vector import (
|
||||
old_metadata_migration,
|
||||
vdb_migrate,
|
||||
)
|
||||
from .vector_space import sample_vector_space_usage
|
||||
|
||||
__all__ = [
|
||||
"add_qdrant_index",
|
||||
@ -62,6 +63,7 @@ __all__ = [
|
||||
"reset_encrypt_key_pair",
|
||||
"reset_password",
|
||||
"restore_workflow_runs",
|
||||
"sample_vector_space_usage",
|
||||
"setup_datasource_oauth_client",
|
||||
"setup_system_tool_oauth_client",
|
||||
"setup_system_trigger_oauth_client",
|
||||
|
||||
@ -113,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
# Validates name encoding for non-Latin characters.
|
||||
name = name.strip().encode("utf-8").decode("utf-8") if name else None
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
# Generate a random password that satisfies the password policy.
|
||||
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
|
||||
for _ in range(100):
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
valid_password(new_password)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
|
||||
return
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
558
api/commands/vector_space.py
Normal file
558
api/commands/vector_space.py
Normal file
@ -0,0 +1,558 @@
|
||||
import csv
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
Dataset,
|
||||
DatasetCollectionBinding,
|
||||
DocumentSegment,
|
||||
DocumentSegmentSummary,
|
||||
SegmentAttachmentBinding,
|
||||
TidbAuthBinding,
|
||||
)
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import IndexingStatus, SegmentStatus, SummaryStatus, TidbAuthBindingStatus
|
||||
from models.model import App, AppAnnotationSetting, MessageAnnotation
|
||||
|
||||
COMMON_EMBEDDING_MODEL_DIMS = {
|
||||
# OpenAI
|
||||
"text-embedding-ada-002": 1536,
|
||||
"text-embedding-3-small": 1536,
|
||||
"text-embedding-3-large": 3072,
|
||||
# Cohere
|
||||
"embed-english-v3.0": 1024,
|
||||
"embed-multilingual-v3.0": 1024,
|
||||
"embed-english-light-v3.0": 384,
|
||||
"embed-multilingual-light-v3.0": 384,
|
||||
# Google
|
||||
"embedding-001": 768,
|
||||
"text-embedding-004": 768,
|
||||
# Voyage
|
||||
"voyage-2": 1024,
|
||||
"voyage-3": 1024,
|
||||
"voyage-3-lite": 512,
|
||||
"voyage-large-2": 1536,
|
||||
"voyage-code-2": 1536,
|
||||
# BAAI BGE
|
||||
"bge-small-en": 384,
|
||||
"bge-small-en-v1.5": 384,
|
||||
"bge-small-zh": 512,
|
||||
"bge-small-zh-v1.5": 512,
|
||||
"bge-base-en": 768,
|
||||
"bge-base-en-v1.5": 768,
|
||||
"bge-base-zh": 768,
|
||||
"bge-base-zh-v1.5": 768,
|
||||
"bge-large-en": 1024,
|
||||
"bge-large-en-v1.5": 1024,
|
||||
"bge-large-zh": 1024,
|
||||
"bge-large-zh-v1.5": 1024,
|
||||
"bge-m3": 1024,
|
||||
# E5
|
||||
"multilingual-e5-small": 384,
|
||||
"multilingual-e5-base": 768,
|
||||
"multilingual-e5-large": 1024,
|
||||
"e5-small-v2": 384,
|
||||
"e5-base-v2": 768,
|
||||
"e5-large-v2": 1024,
|
||||
# M3E
|
||||
"m3e-small": 512,
|
||||
"m3e-base": 768,
|
||||
"m3e-large": 1024,
|
||||
# Jina
|
||||
"jina-embeddings-v2-small-en": 512,
|
||||
"jina-embeddings-v2-base-en": 768,
|
||||
"jina-embeddings-v2-base-zh": 768,
|
||||
"jina-embeddings-v3": 1024,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CollectionPointStats:
|
||||
collection_name: str
|
||||
source_type: str
|
||||
source_id: str
|
||||
model_provider: str | None
|
||||
model_name: str | None
|
||||
segment_points: int = 0
|
||||
child_chunk_points: int = 0
|
||||
summary_points: int = 0
|
||||
attachment_points: int = 0
|
||||
annotation_points: int = 0
|
||||
|
||||
@property
|
||||
def total_points(self) -> int:
|
||||
return (
|
||||
self.segment_points
|
||||
+ self.child_chunk_points
|
||||
+ self.summary_points
|
||||
+ self.attachment_points
|
||||
+ self.annotation_points
|
||||
)
|
||||
|
||||
|
||||
def _parse_overheads(value: str) -> list[int]:
|
||||
overheads = []
|
||||
for item in value.split(","):
|
||||
item = item.strip()
|
||||
if not item:
|
||||
continue
|
||||
overheads.append(int(item))
|
||||
if not overheads:
|
||||
raise click.BadParameter("At least one overhead is required.")
|
||||
return overheads
|
||||
|
||||
|
||||
def _normalize_model_name(model_name: str) -> str:
|
||||
return model_name.strip().split("/")[-1]
|
||||
|
||||
|
||||
def _tidb_storage_usage_bytes(binding: TidbAuthBinding, timeout: float) -> int:
|
||||
if not binding.qdrant_endpoint:
|
||||
raise ValueError("qdrant_endpoint is empty")
|
||||
|
||||
endpoint = binding.qdrant_endpoint.rstrip("/")
|
||||
with httpx.Client(timeout=timeout, verify=False) as client:
|
||||
response = client.get(f"{endpoint}/cluster", headers={"api-key": f"{binding.account}:{binding.password}"})
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
storage = data.get("usage", {}).get("storage", {})
|
||||
row_based = int(storage.get("row_based") or 0)
|
||||
columnar = int(storage.get("columnar") or 0)
|
||||
return row_based + columnar
|
||||
|
||||
|
||||
def _extract_vector_size(collection_payload: dict[str, Any]) -> int | None:
|
||||
vectors = (
|
||||
collection_payload.get("result", {})
|
||||
.get("config", {})
|
||||
.get("params", {})
|
||||
.get("vectors")
|
||||
)
|
||||
if isinstance(vectors, dict):
|
||||
size = vectors.get("size")
|
||||
if isinstance(size, int):
|
||||
return size
|
||||
for vector_config in vectors.values():
|
||||
if isinstance(vector_config, dict) and isinstance(vector_config.get("size"), int):
|
||||
return vector_config["size"]
|
||||
return None
|
||||
|
||||
|
||||
def _qdrant_collection_dim(
|
||||
binding: TidbAuthBinding,
|
||||
collection_name: str,
|
||||
timeout: float,
|
||||
dim_cache: dict[str, int | None],
|
||||
) -> int | None:
|
||||
if collection_name in dim_cache:
|
||||
return dim_cache[collection_name]
|
||||
if not binding.qdrant_endpoint:
|
||||
dim_cache[collection_name] = None
|
||||
return None
|
||||
|
||||
endpoint = binding.qdrant_endpoint.rstrip("/")
|
||||
try:
|
||||
with httpx.Client(timeout=timeout, verify=False) as client:
|
||||
response = client.get(
|
||||
f"{endpoint}/collections/{collection_name}",
|
||||
headers={"api-key": f"{binding.account}:{binding.password}"},
|
||||
)
|
||||
if response.status_code == 404:
|
||||
dim_cache[collection_name] = None
|
||||
return None
|
||||
response.raise_for_status()
|
||||
dim = _extract_vector_size(response.json())
|
||||
dim_cache[collection_name] = dim
|
||||
return dim
|
||||
except Exception:
|
||||
dim_cache[collection_name] = None
|
||||
return None
|
||||
|
||||
|
||||
def _dataset_vector_type(dataset: Dataset) -> str | None:
|
||||
if dataset.index_struct_dict:
|
||||
return dataset.index_struct_dict.get("type")
|
||||
return dify_config.VECTOR_STORE
|
||||
|
||||
|
||||
def _dataset_collection_name(dataset: Dataset) -> str:
|
||||
if dataset.index_struct_dict:
|
||||
vector_store = dataset.index_struct_dict.get("vector_store") or {}
|
||||
collection_name = vector_store.get("class_prefix")
|
||||
if collection_name:
|
||||
return collection_name
|
||||
if dataset.collection_binding_id:
|
||||
binding = db.session.get(DatasetCollectionBinding, dataset.collection_binding_id)
|
||||
if binding:
|
||||
return binding.collection_name
|
||||
return Dataset.gen_collection_name_by_id(dataset.id)
|
||||
|
||||
|
||||
def _active_tidb_bindings(tenant_ids: tuple[str, ...], limit: int, offset: int) -> list[TidbAuthBinding]:
|
||||
stmt = (
|
||||
select(TidbAuthBinding)
|
||||
.where(
|
||||
TidbAuthBinding.tenant_id.is_not(None),
|
||||
TidbAuthBinding.active == True,
|
||||
TidbAuthBinding.status == TidbAuthBindingStatus.ACTIVE,
|
||||
)
|
||||
.order_by(TidbAuthBinding.created_at.desc())
|
||||
)
|
||||
if tenant_ids:
|
||||
stmt = stmt.where(TidbAuthBinding.tenant_id.in_(tenant_ids))
|
||||
else:
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
return list(db.session.scalars(stmt).all())
|
||||
|
||||
|
||||
def _completed_document_filter() -> tuple[Any, ...]:
|
||||
return (
|
||||
DatasetDocument.indexing_status == IndexingStatus.COMPLETED,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
|
||||
|
||||
def _completed_segment_filter() -> tuple[Any, ...]:
|
||||
return (
|
||||
DocumentSegment.status == SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.is_not(None),
|
||||
)
|
||||
|
||||
|
||||
def _count_dataset_points(dataset: Dataset) -> CollectionPointStats:
|
||||
segment_points = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id))
|
||||
.join(DatasetDocument, DatasetDocument.id == DocumentSegment.document_id)
|
||||
.where(
|
||||
DocumentSegment.tenant_id == dataset.tenant_id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DatasetDocument.doc_form != IndexStructureType.PARENT_CHILD_INDEX,
|
||||
*_completed_document_filter(),
|
||||
*_completed_segment_filter(),
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
child_chunk_points = (
|
||||
db.session.scalar(
|
||||
select(func.count(ChildChunk.id))
|
||||
.join(DatasetDocument, DatasetDocument.id == ChildChunk.document_id)
|
||||
.where(
|
||||
ChildChunk.tenant_id == dataset.tenant_id,
|
||||
ChildChunk.dataset_id == dataset.id,
|
||||
ChildChunk.index_node_id.is_not(None),
|
||||
*_completed_document_filter(),
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
summary_points = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegmentSummary.id))
|
||||
.join(DatasetDocument, DatasetDocument.id == DocumentSegmentSummary.document_id)
|
||||
.where(
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
DocumentSegmentSummary.enabled == True,
|
||||
DocumentSegmentSummary.status == SummaryStatus.COMPLETED,
|
||||
DocumentSegmentSummary.summary_index_node_id.is_not(None),
|
||||
*_completed_document_filter(),
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
attachment_points = 0
|
||||
if dataset.is_multimodal:
|
||||
attachment_points = (
|
||||
db.session.scalar(
|
||||
select(func.count(sa.distinct(SegmentAttachmentBinding.attachment_id)))
|
||||
.join(DocumentSegment, DocumentSegment.id == SegmentAttachmentBinding.segment_id)
|
||||
.join(DatasetDocument, DatasetDocument.id == SegmentAttachmentBinding.document_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
|
||||
SegmentAttachmentBinding.dataset_id == dataset.id,
|
||||
*_completed_document_filter(),
|
||||
*_completed_segment_filter(),
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
return CollectionPointStats(
|
||||
collection_name=_dataset_collection_name(dataset),
|
||||
source_type="dataset",
|
||||
source_id=dataset.id,
|
||||
model_provider=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model,
|
||||
segment_points=int(segment_points),
|
||||
child_chunk_points=int(child_chunk_points),
|
||||
summary_points=int(summary_points),
|
||||
attachment_points=int(attachment_points),
|
||||
)
|
||||
|
||||
|
||||
def _dataset_stats_for_tenant(tenant_id: str) -> list[CollectionPointStats]:
|
||||
datasets = db.session.scalars(
|
||||
select(Dataset).where(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY,
|
||||
)
|
||||
).all()
|
||||
|
||||
stats = []
|
||||
for dataset in datasets:
|
||||
if _dataset_vector_type(dataset) != VectorType.TIDB_ON_QDRANT:
|
||||
continue
|
||||
dataset_stats = _count_dataset_points(dataset)
|
||||
if dataset_stats.total_points > 0:
|
||||
stats.append(dataset_stats)
|
||||
return stats
|
||||
|
||||
|
||||
def _annotation_stats_for_tenant(tenant_id: str) -> list[CollectionPointStats]:
|
||||
rows = db.session.execute(
|
||||
select(
|
||||
App.id,
|
||||
DatasetCollectionBinding.provider_name,
|
||||
DatasetCollectionBinding.model_name,
|
||||
DatasetCollectionBinding.collection_name,
|
||||
func.count(MessageAnnotation.id),
|
||||
)
|
||||
.join(AppAnnotationSetting, AppAnnotationSetting.app_id == App.id)
|
||||
.join(DatasetCollectionBinding, DatasetCollectionBinding.id == AppAnnotationSetting.collection_binding_id)
|
||||
.join(MessageAnnotation, MessageAnnotation.app_id == App.id)
|
||||
.where(App.tenant_id == tenant_id)
|
||||
.group_by(
|
||||
App.id,
|
||||
DatasetCollectionBinding.provider_name,
|
||||
DatasetCollectionBinding.model_name,
|
||||
DatasetCollectionBinding.collection_name,
|
||||
)
|
||||
).all()
|
||||
|
||||
return [
|
||||
CollectionPointStats(
|
||||
collection_name=row[3],
|
||||
source_type="annotation",
|
||||
source_id=row[0],
|
||||
model_provider=row[1],
|
||||
model_name=row[2],
|
||||
annotation_points=int(row[4] or 0),
|
||||
)
|
||||
for row in rows
|
||||
if int(row[4] or 0) > 0
|
||||
]
|
||||
|
||||
|
||||
def _resolve_dim(
|
||||
stat: CollectionPointStats,
|
||||
binding: TidbAuthBinding,
|
||||
default_dim: int,
|
||||
fetch_qdrant_dim: bool,
|
||||
timeout: float,
|
||||
dim_cache: dict[str, int | None],
|
||||
) -> tuple[int, str]:
|
||||
if stat.model_provider and stat.model_name:
|
||||
builtin_dim = COMMON_EMBEDDING_MODEL_DIMS.get(_normalize_model_name(stat.model_name))
|
||||
if builtin_dim:
|
||||
return builtin_dim, "builtin_model_map"
|
||||
|
||||
if fetch_qdrant_dim:
|
||||
qdrant_dim = _qdrant_collection_dim(binding, stat.collection_name, timeout, dim_cache)
|
||||
if qdrant_dim:
|
||||
return qdrant_dim, "qdrant"
|
||||
|
||||
return default_dim, "default"
|
||||
|
||||
|
||||
def _mb(value: int | float | Decimal) -> float:
|
||||
return round(float(value) / 1024 / 1024, 4)
|
||||
|
||||
|
||||
def _log(message: str, quiet: bool) -> None:
|
||||
if not quiet:
|
||||
click.echo(message, err=True)
|
||||
|
||||
|
||||
@click.command(
|
||||
"sample-vector-space-usage",
|
||||
help="Sample TiDB vector storage usage and compare it with local formula estimates.",
|
||||
)
|
||||
@click.option("--tenant-id", multiple=True, help="Tenant ID to sample. Can be repeated.")
|
||||
@click.option("--limit", default=20, show_default=True, help="Number of active TiDB tenants to sample.")
|
||||
@click.option("--offset", default=0, show_default=True, help="Offset when sampling active TiDB tenants.")
|
||||
@click.option("--default-dim", default=3072, show_default=True, help="Fallback embedding dimension.")
|
||||
@click.option(
|
||||
"--overheads",
|
||||
default="3584,5120,8192",
|
||||
show_default=True,
|
||||
help="Comma-separated per-point overhead bytes to compare.",
|
||||
)
|
||||
@click.option("--fetch-qdrant-dim/--no-fetch-qdrant-dim", default=True, show_default=True)
|
||||
@click.option("--include-annotations/--exclude-annotations", default=True, show_default=True)
|
||||
@click.option("--timeout", default=10.0, show_default=True, help="HTTP timeout for TiDB/Qdrant calls.")
|
||||
@click.option("--output", type=click.Path(dir_okay=False, path_type=Path), help="CSV output path. Defaults to stdout.")
|
||||
@click.option("--quiet", is_flag=True, help="Suppress progress logs. CSV output is unaffected.")
|
||||
def sample_vector_space_usage(
|
||||
tenant_id: tuple[str, ...],
|
||||
limit: int,
|
||||
offset: int,
|
||||
default_dim: int,
|
||||
overheads: str,
|
||||
fetch_qdrant_dim: bool,
|
||||
include_annotations: bool,
|
||||
timeout: float,
|
||||
output: Path | None,
|
||||
quiet: bool,
|
||||
):
|
||||
overhead_values = _parse_overheads(overheads)
|
||||
bindings = _active_tidb_bindings(tenant_id, limit, offset)
|
||||
sample_scope = f" for tenant_id={','.join(tenant_id)}" if tenant_id else f" with limit={limit}, offset={offset}"
|
||||
_log(
|
||||
f"Sampling {len(bindings)} active TiDB binding(s){sample_scope}.",
|
||||
quiet,
|
||||
)
|
||||
if not bindings:
|
||||
_log("No active TiDB bindings found. Nothing to sample.", quiet)
|
||||
|
||||
fieldnames = [
|
||||
"tenant_id",
|
||||
"cluster_id",
|
||||
"tidb_actual_mb",
|
||||
"total_points",
|
||||
"segment_points",
|
||||
"child_chunk_points",
|
||||
"summary_points",
|
||||
"attachment_points",
|
||||
"annotation_points",
|
||||
"collection_count",
|
||||
"dim_sources",
|
||||
"dims",
|
||||
"errors",
|
||||
]
|
||||
for overhead in overhead_values:
|
||||
fieldnames.extend(
|
||||
[
|
||||
f"estimated_mb_o{overhead}",
|
||||
f"diff_mb_o{overhead}",
|
||||
f"ratio_o{overhead}",
|
||||
]
|
||||
)
|
||||
|
||||
output_file = output.open("w", newline="") if output else None
|
||||
try:
|
||||
writer = csv.DictWriter(output_file or click.get_text_stream("stdout"), fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
for index, binding in enumerate(bindings, start=1):
|
||||
assert binding.tenant_id is not None
|
||||
tenant = binding.tenant_id
|
||||
errors = []
|
||||
dim_cache: dict[str, int | None] = {}
|
||||
_log(f"[{index}/{len(bindings)}] tenant={tenant} cluster={binding.cluster_id}: fetching TiDB usage", quiet)
|
||||
|
||||
try:
|
||||
actual_bytes = _tidb_storage_usage_bytes(binding, timeout)
|
||||
_log(
|
||||
f"[{index}/{len(bindings)}] tenant={tenant}: TiDB actual={_mb(actual_bytes)} MB",
|
||||
quiet,
|
||||
)
|
||||
except Exception as exc:
|
||||
actual_bytes = 0
|
||||
errors.append(f"tidb_usage:{exc.__class__.__name__}:{exc}")
|
||||
_log(
|
||||
f"[{index}/{len(bindings)}] tenant={tenant}: failed to fetch TiDB usage: "
|
||||
f"{exc.__class__.__name__}: {exc}",
|
||||
quiet,
|
||||
)
|
||||
|
||||
_log(f"[{index}/{len(bindings)}] tenant={tenant}: counting local vector points", quiet)
|
||||
collection_stats = _dataset_stats_for_tenant(tenant)
|
||||
if include_annotations:
|
||||
collection_stats.extend(_annotation_stats_for_tenant(tenant))
|
||||
|
||||
total_points = 0
|
||||
segment_points = 0
|
||||
child_chunk_points = 0
|
||||
summary_points = 0
|
||||
attachment_points = 0
|
||||
annotation_points = 0
|
||||
dim_sources: dict[str, int] = {}
|
||||
dims: dict[str, int] = {}
|
||||
estimated_by_overhead = dict.fromkeys(overhead_values, 0)
|
||||
|
||||
for stat in collection_stats:
|
||||
dim, dim_source = _resolve_dim(
|
||||
stat,
|
||||
binding,
|
||||
default_dim,
|
||||
fetch_qdrant_dim,
|
||||
timeout,
|
||||
dim_cache,
|
||||
)
|
||||
dim_sources[dim_source] = dim_sources.get(dim_source, 0) + 1
|
||||
dims[str(dim)] = dims.get(str(dim), 0) + stat.total_points
|
||||
|
||||
total_points += stat.total_points
|
||||
segment_points += stat.segment_points
|
||||
child_chunk_points += stat.child_chunk_points
|
||||
summary_points += stat.summary_points
|
||||
attachment_points += stat.attachment_points
|
||||
annotation_points += stat.annotation_points
|
||||
|
||||
for overhead in overhead_values:
|
||||
estimated_by_overhead[overhead] += stat.total_points * (dim * 4 + overhead)
|
||||
|
||||
_log(
|
||||
f"[{index}/{len(bindings)}] tenant={tenant}: points={total_points}, "
|
||||
f"collections={len(collection_stats)}, dim_sources={json.dumps(dim_sources, sort_keys=True)}",
|
||||
quiet,
|
||||
)
|
||||
|
||||
row: dict[str, Any] = {
|
||||
"tenant_id": tenant,
|
||||
"cluster_id": binding.cluster_id,
|
||||
"tidb_actual_mb": _mb(actual_bytes),
|
||||
"total_points": total_points,
|
||||
"segment_points": segment_points,
|
||||
"child_chunk_points": child_chunk_points,
|
||||
"summary_points": summary_points,
|
||||
"attachment_points": attachment_points,
|
||||
"annotation_points": annotation_points,
|
||||
"collection_count": len(collection_stats),
|
||||
"dim_sources": json.dumps(dim_sources, sort_keys=True),
|
||||
"dims": json.dumps(dims, sort_keys=True),
|
||||
"errors": ";".join(errors),
|
||||
}
|
||||
|
||||
for overhead, estimated_bytes in estimated_by_overhead.items():
|
||||
diff_bytes = estimated_bytes - actual_bytes
|
||||
ratio = round(estimated_bytes / actual_bytes, 6) if actual_bytes > 0 else ""
|
||||
row[f"estimated_mb_o{overhead}"] = _mb(estimated_bytes)
|
||||
row[f"diff_mb_o{overhead}"] = _mb(diff_bytes)
|
||||
row[f"ratio_o{overhead}"] = ratio
|
||||
|
||||
writer.writerow(row)
|
||||
_log(f"[{index}/{len(bindings)}] tenant={tenant}: row written", quiet)
|
||||
finally:
|
||||
if output_file:
|
||||
output_file.close()
|
||||
@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Creators Platform integration
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable Creators Platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client ID for Creators Platform integration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -499,35 +520,6 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -903,17 +895,6 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1188,14 +1169,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
@ -1427,6 +1400,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
CreatorsPlatformConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
"name": "Website Generator"
|
||||
},
|
||||
"app_id": "b53545b1-79ea-4da3-b31a-c39391c6f041",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -35,7 +35,7 @@
|
||||
"name": "Investment Analysis Report Copilot"
|
||||
},
|
||||
"app_id": "a23b57fa-85da-49c0-a571-3aff375976c1",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n",
|
||||
"is_listed": true,
|
||||
@ -51,7 +51,7 @@
|
||||
"name": "Workflow Planning Assistant "
|
||||
},
|
||||
"app_id": "f3303a7d-a81c-404e-b401-1f8711c998c1",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "An assistant that helps you plan and select the right node for a workflow (V0.6.0). ",
|
||||
"is_listed": true,
|
||||
@ -67,7 +67,7 @@
|
||||
"name": "Automated Email Reply "
|
||||
},
|
||||
"app_id": "e9d92058-7d20-4904-892f-75d90bef7587",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Reply emails using Gmail API. It will automatically retrieve email in your inbox and create a response in Gmail. \nConfigure your Gmail API in Google Cloud Console. ",
|
||||
"is_listed": true,
|
||||
@ -83,7 +83,7 @@
|
||||
"name": "Book Translation "
|
||||
},
|
||||
"app_id": "98b87f88-bd22-4d86-8b74-86beba5e0ed4",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A workflow designed to translate a full book up to 15000 tokens per run. Uses Code node to separate text into chunks and Iteration to translate each chunk. ",
|
||||
"is_listed": true,
|
||||
@ -99,7 +99,7 @@
|
||||
"name": "Python bug fixer"
|
||||
},
|
||||
"app_id": "cae337e6-aec5-4c7b-beca-d6f1a808bd5e",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -115,7 +115,7 @@
|
||||
"name": "Code Interpreter"
|
||||
},
|
||||
"app_id": "d077d587-b072-4f2c-b631-69ed1e7cdc0f",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Code interpreter, clarifying the syntax and semantics of the code.",
|
||||
"is_listed": true,
|
||||
@ -131,7 +131,7 @@
|
||||
"name": "SVG Logo Design "
|
||||
},
|
||||
"app_id": "73fbb5f1-c15d-4d74-9cc8-46d9db9b2cca",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL·E 3. ",
|
||||
"is_listed": true,
|
||||
@ -147,7 +147,7 @@
|
||||
"name": "Long Story Generator (Iteration) "
|
||||
},
|
||||
"app_id": "5efb98d7-176b-419c-b6ef-50767391ab62",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A workflow demonstrating how to use Iteration node to generate long article that is longer than the context length of LLMs. ",
|
||||
"is_listed": true,
|
||||
@ -163,7 +163,7 @@
|
||||
"name": "Text Summarization Workflow"
|
||||
},
|
||||
"app_id": "f00c4531-6551-45ee-808f-1d7903099515",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Based on users' choice, retrieve external knowledge to more accurately summarize articles.",
|
||||
"is_listed": true,
|
||||
@ -179,7 +179,7 @@
|
||||
"name": "YouTube Channel Data Analysis"
|
||||
},
|
||||
"app_id": "be591209-2ca8-410f-8f3b-ca0e530dd638",
|
||||
"category": "Agent",
|
||||
"categories": ["Agent"],
|
||||
"copyright": "Dify.AI",
|
||||
"description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ",
|
||||
"is_listed": true,
|
||||
@ -195,7 +195,7 @@
|
||||
"name": "Article Grading Bot"
|
||||
},
|
||||
"app_id": "a747f7b4-c48b-40d6-b313-5e628232c05f",
|
||||
"category": "Writing",
|
||||
"categories": ["Writing"],
|
||||
"copyright": null,
|
||||
"description": "Assess the quality of articles and text based on user defined criteria. ",
|
||||
"is_listed": true,
|
||||
@ -211,7 +211,7 @@
|
||||
"name": "SEO Blog Generator"
|
||||
},
|
||||
"app_id": "18f3bd03-524d-4d7a-8374-b30dbe7c69d5",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Workflow for retrieving information from the internet, followed by segmented generation of SEO blogs.",
|
||||
"is_listed": true,
|
||||
@ -227,7 +227,7 @@
|
||||
"name": "SQL Creator"
|
||||
},
|
||||
"app_id": "050ef42e-3e0c-40c1-a6b6-a64f2c49d744",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.",
|
||||
"is_listed": true,
|
||||
@ -243,7 +243,7 @@
|
||||
"name": "Sentiment Analysis "
|
||||
},
|
||||
"app_id": "f06bf86b-d50c-4895-a942-35112dbe4189",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Batch sentiment analysis of text, followed by JSON output of sentiment classification along with scores.",
|
||||
"is_listed": true,
|
||||
@ -259,7 +259,7 @@
|
||||
"name": "Strategic Consulting Expert"
|
||||
},
|
||||
"app_id": "7e8ca1ae-02f2-4b5f-979e-62d19133bee2",
|
||||
"category": "Assistant",
|
||||
"categories": ["Assistant"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "I can answer your questions related to strategic marketing.",
|
||||
"is_listed": true,
|
||||
@ -275,7 +275,7 @@
|
||||
"name": "Code Converter"
|
||||
},
|
||||
"app_id": "4006c4b2-0735-4f37-8dbb-fb1a8c5bd87a",
|
||||
"category": "Programming",
|
||||
"categories": ["Programming"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "This is an application that provides the ability to convert code snippets in multiple programming languages. You can input the code you wish to convert, select the target programming language, and get the desired output.",
|
||||
"is_listed": true,
|
||||
@ -291,7 +291,7 @@
|
||||
"name": "Question Classifier + Knowledge + Chatbot "
|
||||
},
|
||||
"app_id": "d9f6b733-e35d-4a40-9f38-ca7bbfa009f7",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, a chatbot capable of identifying intents alongside with a knowledge base.",
|
||||
"is_listed": true,
|
||||
@ -307,7 +307,7 @@
|
||||
"name": "AI Front-end interviewer"
|
||||
},
|
||||
"app_id": "127efead-8944-4e20-ba9d-12402eb345e0",
|
||||
"category": "HR",
|
||||
"categories": ["HR"],
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.",
|
||||
"is_listed": true,
|
||||
@ -323,7 +323,7 @@
|
||||
"name": "Knowledge Retrieval + Chatbot "
|
||||
},
|
||||
"app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, A chatbot with a knowledge base. ",
|
||||
"is_listed": true,
|
||||
@ -339,7 +339,7 @@
|
||||
"name": "Email Assistant Workflow "
|
||||
},
|
||||
"app_id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "A multifunctional email assistant capable of summarizing, replying, composing, proofreading, and checking grammar.",
|
||||
"is_listed": true,
|
||||
@ -355,7 +355,7 @@
|
||||
"name": "Customer Review Analysis Workflow "
|
||||
},
|
||||
"app_id": "9c0cd31f-4b62-4005-adf5-e3888d08654a",
|
||||
"category": "Workflow",
|
||||
"categories": ["Workflow"],
|
||||
"copyright": null,
|
||||
"description": "Utilize LLM (Large Language Models) to classify customer reviews and forward them to the internal system.",
|
||||
"is_listed": true,
|
||||
|
||||
@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response):
|
||||
# Try to extract filename from URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
# Decode percent-encoded characters in the path segment
|
||||
filename = urllib.parse.unquote(os.path.basename(url_path))
|
||||
|
||||
# If filename couldn't be extracted, use Content-Disposition header
|
||||
if not filename:
|
||||
|
||||
6
api/controllers/common/human_input.py
Normal file
6
api/controllers/common/human_input.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
@ -8,6 +9,7 @@ from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
@ -66,22 +69,19 @@ class AppListQuery(BaseModel):
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("Unsupported tag_ids type.")
|
||||
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
if not items:
|
||||
return None
|
||||
|
||||
@ -91,6 +91,26 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -455,7 +475,7 @@ class AppListApi(Resource):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
@ -692,6 +712,32 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
|
||||
@ -60,7 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -158,8 +159,13 @@ class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
app_ids: str = Field(..., description="Comma-separated app IDs")
|
||||
class WorkflowOnlineUsersPayload(BaseModel):
|
||||
app_ids: list[str] = Field(default_factory=list, description="App IDs")
|
||||
|
||||
@field_validator("app_ids")
|
||||
@classmethod
|
||||
def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
|
||||
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
@ -186,7 +192,7 @@ reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -1384,19 +1390,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
def post(self):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
|
||||
app_ids = args.app_ids
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
|
||||
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
@ -1404,13 +1410,24 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
|
||||
|
||||
users_json_by_app_id: dict[str, Any] = {}
|
||||
for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
|
||||
app_id_batch = ordered_accessible_app_ids[
|
||||
start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
|
||||
]
|
||||
pipe = redis_client.pipeline(transaction=False)
|
||||
for app_id in app_id_batch:
|
||||
pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
|
||||
users_json_batch = pipe.execute()
|
||||
for app_id, users_json in zip(app_id_batch, users_json_batch):
|
||||
users_json_by_app_id[app_id] = users_json
|
||||
|
||||
results = []
|
||||
for app_id in app_ids:
|
||||
if app_id not in accessible_app_ids:
|
||||
continue
|
||||
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
for app_id in ordered_accessible_app_ids:
|
||||
users_json = users_json_by_app_id.get(app_id, {})
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
|
||||
@ -75,14 +75,15 @@ console_ns.schema_model(
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
match value:
|
||||
case FileSegment():
|
||||
return value.value.model_dump()
|
||||
case ArrayFileSegment():
|
||||
return [i.model_dump() for i in value.value]
|
||||
case SegmentGroup():
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
case _:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
@ -38,6 +38,48 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
|
||||
if normalized_record.get("child_chunks") is None:
|
||||
normalized_record["child_chunks"] = []
|
||||
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
@ -75,7 +117,12 @@ class DatasetsHitTestingBase:
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -52,7 +52,7 @@ class RecommendedAppResponse(ResponseModel):
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
custom_disclaimer: str | None = None
|
||||
category: str | None = None
|
||||
categories: list[str] = Field(default_factory=list)
|
||||
position: int | None = None
|
||||
is_listed: bool | None = None
|
||||
can_trial: bool | None = None
|
||||
|
||||
@ -8,10 +8,10 @@ from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
@ -56,6 +51,11 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_recipient_type(form: Form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
|
||||
raise NotFoundError("form not found")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -99,10 +99,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
# So we need to assert it manually.
|
||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
||||
|
||||
@ -32,7 +32,7 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
@ -152,41 +152,68 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings")
|
||||
class TagBindingCollectionApi(Resource):
|
||||
"""Canonical collection resource for tag binding creation."""
|
||||
|
||||
@console_ns.doc("create_tag_binding")
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
class TagBindingRemoveApi(Resource):
|
||||
"""Batch resource for tag binding deletion."""
|
||||
|
||||
@console_ns.doc("remove_tag_bindings")
|
||||
@console_ns.doc(description="Remove one or more tag bindings from a target.")
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return _remove_tag_bindings()
|
||||
|
||||
@ -8,6 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -45,6 +46,8 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@ -322,9 +325,24 @@ class AccountAvatarApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
avatar = args.avatar
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.tenant_id != current_tenant_id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@ -595,13 +613,25 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
# Default to the initial phase; any legacy/unexpected client input is
|
||||
# coerced back to `old_email` so we never trust the caller to declare
|
||||
# later phases without a verified predecessor token.
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# The token used to request a new-email code must come from the
|
||||
# old-email verification step. This prevents the bypass described
|
||||
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@ -620,7 +650,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
phase=send_phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@ -655,12 +685,31 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Only advance tokens that were minted by the matching send-code step;
|
||||
# refuse tokens that have already progressed or lack a phase marker so
|
||||
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
|
||||
# is strictly enforced.
|
||||
phase_transitions = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
# Refresh token data by generating a new token that carries the
|
||||
# upgraded phase so later steps can check it.
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@ -690,13 +739,29 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
# Only tokens that completed both verification phases may be used to
|
||||
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
|
||||
# the initial send-code step could be replayed directly here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Bind the new email to the token that was mailed and verified, so a
|
||||
# verified token cannot be reused with a different `new_email` value.
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
# Revoke only after all checks pass so failed attempts don't burn a
|
||||
# legitimately verified token.
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@ -1,3 +1,11 @@
|
||||
"""Console workspace endpoint controllers.
|
||||
|
||||
This module exposes workspace-scoped plugin endpoint management APIs. The
|
||||
canonical write routes follow resource-oriented paths, while the historical
|
||||
verb-based aliases stay available as deprecated resources so OpenAPI metadata
|
||||
marks only the legacy paths as deprecated.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
class EndpointUpdatePayload(BaseModel):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LegacyEndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
@ -76,6 +89,7 @@ register_schema_models(
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
LegacyEndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
@ -88,8 +102,60 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints")
|
||||
class EndpointCollectionApi(Resource):
|
||||
"""Canonical collection resource for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@ -104,22 +170,33 @@ class EndpointCreateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
return _create_endpoint()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class DeprecatedEndpointCreateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.route("/workspaces/current/endpoints/<string:id>")
|
||||
class EndpointItemApi(Resource):
|
||||
"""Canonical item resource for endpoint updates and deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class DeprecatedEndpointDeleteApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for deleting a plugin endpoint. "
|
||||
"Use DELETE /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
class DeprecatedEndpointUpdateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint updates."""
|
||||
|
||||
@console_ns.doc("update_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating a plugin endpoint. "
|
||||
"Use PATCH /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
@ -233,19 +357,8 @@ class EndpointUpdateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
|
||||
@ -876,10 +876,10 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
|
||||
tenant_id=current_tenant_id, provider=provider, id=payload.id
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
from . import (
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
@ -1,33 +0,0 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
|
||||
|
||||
def emit_app_run(*, app_id: str, tenant_id: str, caller_kind: str, mode: str) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
},
|
||||
)
|
||||
@ -1,143 +0,0 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
@ -1,112 +0,0 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Server-side cap on `limit` query param for any /openapi/v1/* list endpoint.
|
||||
# Sibling endpoints (`/apps`, `/account/sessions`, future routes) all clamp to
|
||||
# this; do not introduce per-endpoint caps without raising the constant.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
tags: list[dict[str, str]] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[dict[str, str]] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
@ -1,236 +0,0 @@
|
||||
"""User-scoped account endpoints. /account is the bearer-authed
|
||||
identity read; /account/sessions and /account/sessions/<id> manage
|
||||
the user's active OAuth tokens.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import and_, select, update
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT, PaginationEnvelope
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return {
|
||||
"subject_type": ctx.subject_type,
|
||||
"subject_email": ctx.subject_email,
|
||||
"subject_issuer": ctx.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
}
|
||||
|
||||
account = (
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
|
||||
)
|
||||
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return {
|
||||
"subject_type": ctx.subject_type,
|
||||
"subject_email": ctx.subject_email or (account.email if account else None),
|
||||
"account": _account_payload(account) if account else None,
|
||||
"workspaces": [_workspace_payload(m) for m in memberships],
|
||||
"default_workspace_id": default_ws_id,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
_revoke_token_by_id(str(ctx.token_id))
|
||||
return {"status": "revoked"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = db.session.execute(
|
||||
select(
|
||||
OAuthAccessToken.id,
|
||||
OAuthAccessToken.prefix,
|
||||
OAuthAccessToken.client_id,
|
||||
OAuthAccessToken.device_label,
|
||||
OAuthAccessToken.created_at,
|
||||
OAuthAccessToken.last_used_at,
|
||||
OAuthAccessToken.expires_at,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
*_subject_match(ctx),
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
OAuthAccessToken.token_hash.is_not(None),
|
||||
OAuthAccessToken.expires_at > now,
|
||||
)
|
||||
)
|
||||
.order_by(OAuthAccessToken.created_at.desc())
|
||||
).all()
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"prefix": r.prefix,
|
||||
"client_id": r.client_id,
|
||||
"device_label": r.device_label,
|
||||
"created_at": _iso(r.created_at),
|
||||
"last_used_at": _iso(r.last_used_at),
|
||||
"expires_at": _iso(r.expires_at),
|
||||
}
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# Subject-match guard. 404 (not 403) on cross-subject so the
|
||||
# endpoint doesn't leak token IDs that belong to other subjects.
|
||||
owns = db.session.execute(
|
||||
select(OAuthAccessToken.id).where(
|
||||
and_(
|
||||
OAuthAccessToken.id == session_id,
|
||||
*_subject_match(ctx),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if owns is None:
|
||||
raise NotFound("session not found")
|
||||
|
||||
_revoke_token_by_id(session_id)
|
||||
return {"status": "revoked"}, 200
|
||||
|
||||
|
||||
def _subject_match(ctx: AuthContext) -> tuple:
|
||||
"""Where-clauses that scope a query to the bearer's subject. Works
|
||||
for both account (account_id) and external_sso (email + issuer).
|
||||
"""
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return (OAuthAccessToken.account_id == str(ctx.account_id),)
|
||||
return (
|
||||
OAuthAccessToken.subject_email == ctx.subject_email,
|
||||
OAuthAccessToken.subject_issuer == ctx.subject_issuer,
|
||||
OAuthAccessToken.account_id.is_(None),
|
||||
)
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _revoke_token_by_id(token_id: str) -> None:
|
||||
# Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
|
||||
# makes double-revoke idempotent.
|
||||
row = (
|
||||
db.session.query(OAuthAccessToken.token_hash)
|
||||
.filter(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
pre_revoke_hash = row[0] if row else None
|
||||
|
||||
stmt = (
|
||||
update(OAuthAccessToken)
|
||||
.where(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||
)
|
||||
db.session.execute(stmt)
|
||||
db.session.commit()
|
||||
|
||||
if pre_revoke_hash:
|
||||
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _load_memberships(account_id):
|
||||
return (
|
||||
db.session.query(TenantAccountJoin, Tenant)
|
||||
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.filter(TenantAccountJoin.account_id == account_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> dict:
|
||||
join, tenant = row
|
||||
return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}
|
||||
|
||||
|
||||
def _account_payload(account) -> dict:
|
||||
return {"id": str(account.id), "email": account.email, "name": account.name}
|
||||
@ -1,198 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError, field_validator
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import (
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
WorkflowRunResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | UUID | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return helper.uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _unpack_blocking(response: Any) -> Mapping[str, Any]:
|
||||
if isinstance(response, tuple):
|
||||
response = response[0]
|
||||
if not isinstance(response, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return response
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, ChatMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, CompletionMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, WorkflowRunResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
try:
|
||||
stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
return blocking_body, 200
|
||||
@ -1,315 +0,0 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → sets `g.auth_ctx` before `require_scope` reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.app_service import AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx = g.auth_ctx
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||
raise NotFound("app not found")
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = list(
|
||||
db.session.execute(
|
||||
sa.select(App).where(
|
||||
App.name == app_id,
|
||||
App.tenant_id == workspace_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[{"name": t.name} for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""`mode` is a closed enum — unknown values 422 instead of silently-empty data."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||
return PaginationEnvelope[AppListRow].build(page=1, limit=0, total=0, items=[]).model_dump(mode="json"), 200
|
||||
|
||||
try:
|
||||
query = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
PaginationEnvelope[AppListRow]
|
||||
.build(page=query.page, limit=query.limit, total=0, items=[])
|
||||
.model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
if parsed_uuid is not None:
|
||||
app = db.session.get(App, str(parsed_uuid))
|
||||
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[{"name": t.name} for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = PaginationEnvelope[AppListRow].build(page=1, limit=1, total=1, items=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
args: dict[str, Any] = {
|
||||
"page": query.page,
|
||||
"limit": query.limit,
|
||||
"mode": query.mode.value if query.mode else "",
|
||||
"name": query.name,
|
||||
"status": "normal",
|
||||
}
|
||||
if tag_ids:
|
||||
args["tag_ids"] = tag_ids
|
||||
|
||||
pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, args)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name: str | None = None
|
||||
if pagination.items:
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[{"name": t.name} for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
env = PaginationEnvelope[AppListRow].build(
|
||||
page=query.page, limit=query.limit, total=int(pagination.total), items=items
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,101 +0,0 @@
|
||||
"""GET /openapi/v1/apps/permitted — external-subject app discovery (EE only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_EXT_SSO,
|
||||
Scope,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
|
||||
class AppPermittedListQuery(BaseModel):
|
||||
"""Strict (`extra='forbid'`) — rejects `workspace_id`/`tag`/etc. that are valid on /apps but not here."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/permitted")
|
||||
class AppPermittedListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED),
|
||||
validate_bearer(accept=ACCEPT_USER_EXT_SSO),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
def get(self):
|
||||
try:
|
||||
query = AppPermittedListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PaginationEnvelope[AppListRow].build(
|
||||
page=query.page, limit=query.limit, total=page_result.total, items=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id = {
|
||||
str(a.id): a
|
||||
for a in db.session.execute(sa.select(App).where(App.id.in_(page_result.app_ids))).scalars().all()
|
||||
}
|
||||
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
|
||||
tenants_by_id = {
|
||||
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
|
||||
}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
|
||||
# total/has_more reflect the EE-side allow-list; len(items) may be < limit when local rows are dropped.
|
||||
env = PaginationEnvelope[AppListRow].build(
|
||||
page=query.page, limit=query.limit, total=page_result.total, items=items
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,3 +0,0 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
@ -1,43 +0,0 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
@ -1,46 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
request: Request
|
||||
required_scope: Scope
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
@ -1,41 +0,0 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
ctx = Context(request=request, required_scope=scope)
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
@ -1,131 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
_extract_bearer, # type: ignore[attr-defined]
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
token = _extract_bearer(ctx.request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling).
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = (ctx.request.view_args or {}).get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
@ -1,115 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL via the workspace-auth inner API.
|
||||
|
||||
Used when webapp-auth is enabled (EE deployment). The inner-API
|
||||
allowlist is the source of truth.
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_email is None or ctx.app is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=ctx.subject_email,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user)
|
||||
user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
@ -1,9 +0,0 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
@ -1,392 +0,0 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
PREFIX_OAUTH_ACCOUNT,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Request / query schemas
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=PREFIX_OAUTH_ACCOUNT,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
"""Pre-render the poll-response body so the unauthenticated poll
|
||||
handler doesn't re-query accounts/tenants for authz data.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
rows = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.account_id == account.id)
|
||||
.all()
|
||||
)
|
||||
workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w["id"] == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0]["id"]
|
||||
|
||||
return {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": {"id": str(account.id), "email": account.email, "name": account.name},
|
||||
"workspaces": workspaces,
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
@ -1,287 +0,0 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from controllers.openapi import bp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = (claims.get("user_code") or "").strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
iss = request.host_url.rstrip("/")
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims["email"],
|
||||
subject_issuer=claims["issuer"],
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or csrf_header != claims.csrf_token:
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
@ -1,89 +0,0 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
|
||||
return {"workspaces": []}, 200
|
||||
|
||||
rows = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == str(ctx.account_id))
|
||||
.order_by(Tenant.created_at.asc())
|
||||
).all()
|
||||
|
||||
return {"workspaces": list(starmap(_workspace_summary, rows))}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = g.auth_ctx
|
||||
# External SSO + missing account → never a member of anything; 404.
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
row = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
Tenant.id == workspace_id,
|
||||
TenantAccountJoin.account_id == str(ctx.account_id),
|
||||
)
|
||||
).first()
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> dict:
|
||||
return {
|
||||
"id": str(tenant.id),
|
||||
"name": tenant.name,
|
||||
"role": getattr(membership, "role", ""),
|
||||
"status": tenant.status,
|
||||
"current": getattr(membership, "current", False),
|
||||
}
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> dict:
|
||||
return {
|
||||
"id": str(tenant.id),
|
||||
"name": tenant.name,
|
||||
"role": getattr(membership, "role", ""),
|
||||
"status": tenant.status,
|
||||
"current": getattr(membership, "current", False),
|
||||
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
}
|
||||
@ -23,9 +23,11 @@ from .app import (
|
||||
conversation,
|
||||
file,
|
||||
file_preview,
|
||||
human_input_form,
|
||||
message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
from .dataset import (
|
||||
dataset,
|
||||
@ -50,6 +52,7 @@ __all__ = [
|
||||
"file",
|
||||
"file_preview",
|
||||
"hit_testing",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"message",
|
||||
"metadata",
|
||||
@ -58,6 +61,7 @@ __all__ = [
|
||||
"segment",
|
||||
"site",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
||||
api.add_namespace(service_api_ns)
|
||||
|
||||
137
api/controllers/service_api/app/human_input_form.py
Normal file
137
api/controllers/service_api/app/human_input_form.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""
|
||||
Service API human input form endpoints.
|
||||
|
||||
This module exposes app-token authenticated APIs for fetching and submitting
|
||||
paused human input forms in workflow/chatflow runs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
|
||||
# Keep app-token callers scoped to the public web-form surface; internal HITL
|
||||
# routes must continue to flow through console-only authentication.
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@service_api_ns.route("/form/human_input/<string:form_token>")
|
||||
class WorkflowHumanInputFormApi(Resource):
|
||||
@service_api_ns.doc("get_human_input_form")
|
||||
@service_api_ns.doc(description="Get a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Form retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Form not found",
|
||||
412: "Form already submitted or expired",
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, form_token: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@service_api_ns.doc("submit_human_input_form")
|
||||
@service_api_ns.doc(description="Submit a paused human input form by token")
|
||||
@service_api_ns.doc(params={"form_token": "Human input form token"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Form submitted successfully",
|
||||
400: "Bad request - invalid submission data",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Form not found",
|
||||
412: "Form already submitted or expired",
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, form_token: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%s", form.id)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_end_user_id=end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
142
api/controllers/service_api/app/workflow_events.py
Normal file
142
api/controllers/service_api/app/workflow_events.py
Normal file
@ -0,0 +1,142 @@
|
||||
"""
|
||||
Service API workflow resume event stream endpoints.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@service_api_ns.route("/workflow/<string:task_id>/events")
|
||||
class WorkflowEventsApi(Resource):
|
||||
"""Service API for getting workflow execution events after resume."""
|
||||
|
||||
@service_api_ns.doc("get_workflow_events")
|
||||
@service_api_ns.doc(description="Get workflow execution events stream after resume")
|
||||
@service_api_ns.doc(
|
||||
params={
|
||||
"task_id": "Workflow run ID",
|
||||
"user": "End user identifier (query param)",
|
||||
"include_state_snapshot": (
|
||||
"Whether to replay from persisted state snapshot, "
|
||||
'specify `"true"` to include a status snapshot of executed nodes'
|
||||
),
|
||||
"continue_on_pause": (
|
||||
"Whether to keep the stream open across workflow_paused events,"
|
||||
'specify `"true"` to keep the stream open for `workflow_paused` events.'
|
||||
),
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "SSE event stream",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Workflow run not found",
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
@ -2,7 +2,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str
|
||||
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
|
||||
|
||||
tag_ids: list[str] = Field(default_factory=list)
|
||||
tag_id: str | None = None
|
||||
target_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_legacy_tag_id(cls, data: object) -> object:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("tag_ids") and data.get("tag_id"):
|
||||
return {**data, "tag_ids": [data["tag_id"]]}
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tag_ids(self) -> "TagUnbindingPayload":
|
||||
if not self.tag_ids:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return self
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc("unbind_dataset_tags")
|
||||
@service_api_ns.doc(description="Unbind tags from a dataset")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Tag unbound successfully",
|
||||
204: "Tags unbound successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
"""Service API endpoints for dataset document management.
|
||||
|
||||
The canonical Service API paths use hyphenated route segments. Legacy underscore
|
||||
aliases remain registered for backward compatibility, but they must stay marked
|
||||
deprecated in generated API docs so clients migrate toward the canonical paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
@ -117,12 +125,137 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Create a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id_str,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
"""Resource for the canonical text document creation route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@ -138,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document creation."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a new document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/document/create-by-text instead."
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document created successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text through the deprecated underscore alias."""
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
"""Resource for the canonical text document update route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@ -229,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document updates."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text through the deprecated underscore alias."""
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -400,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from an uploaded file for canonical and deprecated routes."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args: dict[str, object] = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Deprecated resource aliases for file document updates."""
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc("update_document_by_file_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by uploading a file. "
|
||||
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
@ -419,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file through the deprecated file-update aliases."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
@ -808,6 +886,22 @@ class DocumentApi(DatasetApiResource):
|
||||
|
||||
return response
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file on the canonical document resource."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
@service_api_ns.doc("delete_document")
|
||||
@service_api_ns.doc(description="Delete a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
||||
@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert audio to text"""
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
@ -20,7 +22,11 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
Validate and set defaults for suggested questions feature.
|
||||
|
||||
Optional fields:
|
||||
- prompt: custom instruction prompt.
|
||||
- model: provider/model configuration for suggested question generation.
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
@ -39,4 +45,27 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
|
||||
|
||||
prompt = config["suggested_questions_after_answer"].get("prompt")
|
||||
if prompt is not None and not isinstance(prompt, str):
|
||||
raise ValueError("prompt in suggested_questions_after_answer must be of string type")
|
||||
if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
|
||||
raise ValueError(
|
||||
f"prompt in suggested_questions_after_answer must be less than or equal to "
|
||||
f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
|
||||
)
|
||||
|
||||
if "model" in config["suggested_questions_after_answer"]:
|
||||
model_config = config["suggested_questions_after_answer"]["model"]
|
||||
if not isinstance(model_config, dict):
|
||||
raise ValueError("model in suggested_questions_after_answer must be of object type")
|
||||
|
||||
if "provider" not in model_config or not isinstance(model_config["provider"], str):
|
||||
raise ValueError("provider in suggested_questions_after_answer.model must be of string type")
|
||||
|
||||
if "name" not in model_config or not isinstance(model_config["name"], str):
|
||||
raise ValueError("name in suggested_questions_after_answer.model must be of string type")
|
||||
|
||||
if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
|
||||
raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")
|
||||
|
||||
return config, ["suggested_questions_after_answer"]
|
||||
|
||||
@ -34,7 +34,11 @@ from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -655,7 +659,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse
|
||||
| AdvancedChatPausedBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
@ -12,22 +12,40 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
|
||||
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
|
||||
paused_data = blocking_response.data.model_dump(mode="json")
|
||||
return {
|
||||
"event": StreamEvent.WORKFLOW_PAUSED.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
"conversation_id": blocking_response.data.conversation_id,
|
||||
"mode": blocking_response.data.mode,
|
||||
"answer": blocking_response.data.answer,
|
||||
"metadata": blocking_response.data.metadata,
|
||||
"created_at": blocking_response.data.created_at,
|
||||
"workflow_run_id": blocking_response.data.workflow_run_id,
|
||||
"data": paused_data,
|
||||
}
|
||||
|
||||
response = {
|
||||
"event": "message",
|
||||
"event": StreamEvent.MESSAGE.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -41,7 +59,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -50,7 +70,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
response = cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
metadata = response.get("metadata", {})
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
if isinstance(metadata, dict):
|
||||
response["metadata"] = cls._get_simple_metadata(metadata)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -53,14 +53,18 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
@ -210,7 +214,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
Generator[ChatbotAppStreamResponse, None, None],
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
@ -226,14 +236,39 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
human_input_responses: list[HumanInputRequiredResponse] = []
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=stream_response.data.paused_nodes,
|
||||
reasons=stream_response.data.reasons,
|
||||
status=stream_response.data.status,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
),
|
||||
)
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {}
|
||||
if stream_response.metadata:
|
||||
@ -254,8 +289,41 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
else:
|
||||
continue
|
||||
|
||||
if human_input_responses:
|
||||
return self._build_paused_blocking_response_from_human_input(human_input_responses)
|
||||
|
||||
raise ValueError("queue listening stopped unexpectedly.")
|
||||
|
||||
def _build_paused_blocking_response_from_human_input(
|
||||
self, human_input_responses: list[HumanInputRequiredResponse]
|
||||
) -> AdvancedChatPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
workflow_run_id=human_input_responses[-1].workflow_run_id,
|
||||
answer=self._task_state.answer,
|
||||
metadata=self._message_end_to_stream_response().metadata,
|
||||
created_at=self._message_created_at,
|
||||
paused_nodes=paused_nodes,
|
||||
reasons=reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
|
||||
total_tokens=runtime_state.total_tokens,
|
||||
total_steps=runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[ChatbotAppStreamResponse, Any, None]:
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@ -11,8 +13,10 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppGenerateResponseConverter(ABC):
|
||||
_blocking_response_type: type[AppBlockingResponse]
|
||||
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
@classmethod
|
||||
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
|
||||
return cast(TBlockingResponse, response)
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
@ -20,7 +24,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
|
||||
else:
|
||||
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -29,7 +33,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
|
||||
else:
|
||||
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -39,12 +43,12 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
|
||||
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -106,13 +110,13 @@ class AppGenerateResponseConverter(ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
@ -126,7 +130,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data: dict[str, Any] | None = None
|
||||
data: dict[str, JsonValue] | None = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,11 +14,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = ChatbotAppBlockingResponse
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -52,6 +52,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -336,7 +337,26 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=(
|
||||
HumanInputSurface.SERVICE_API
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
pause_reasons = enrich_human_input_pause_reasons(
|
||||
pause_reasons,
|
||||
form_tokens_by_form_id=form_token_by_form_id,
|
||||
expiration_times_by_form_id={
|
||||
form_id: int(expiration_time.timestamp())
|
||||
for form_id, expiration_time in expiration_times_by_form_id.items()
|
||||
},
|
||||
)
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -12,17 +14,15 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = CompletionAppBlockingResponse
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
response: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
||||
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from models.model import AppMode
|
||||
@ -26,6 +27,7 @@ class MessageGenerator:
|
||||
idle_timeout=300,
|
||||
ping_interval: float = 10.0,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
@ -33,4 +35,5 @@ class MessageGenerator:
|
||||
idle_timeout=idle_timeout,
|
||||
ping_interval=ping_interval,
|
||||
on_subscribe=on_subscribe,
|
||||
terminal_events=terminal_events,
|
||||
)
|
||||
|
||||
@ -13,11 +13,9 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@ -26,7 +24,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@ -27,7 +27,11 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppPausedBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
@ -627,7 +631,11 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
) -> (
|
||||
WorkflowAppBlockingResponse
|
||||
| WorkflowAppPausedBlockingResponse
|
||||
| Generator[WorkflowAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -59,7 +59,7 @@ def stream_topic_events(
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
if terminal_events is None:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user