mirror of
https://github.com/langgenius/dify.git
synced 2026-06-29 02:18:12 +08:00
Compare commits
1 Commits
laipz8200/
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| f0c32f0d52 |
@ -36,30 +36,19 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
|
||||
- Do not replace prop drilling with one top-level hook that returns a large view model and then thread that object through section props. Move each hook, query, derived value, and handler to the concrete section that consumes it, or use feature-scoped Jotai atoms for simple shared form/UI state when siblings need the same source of truth.
|
||||
- When using feature-scoped Jotai state for a form, drawer, or other secondary surface, scope the store to that surface instance when stale cross-instance state is possible. Initialize stable config at the owning boundary, then let descendants read only the atoms or purpose-named hooks they actually need.
|
||||
- For Jotai-backed surfaces, put shared query atoms, mutation atoms, derived state, and write actions in the feature state file when they coordinate multiple descendants. The lowest-owner rule still applies to independent visual surfaces that do not participate in shared state.
|
||||
- For repeated row/menu action surfaces that need reset, hydrate the stable identity at the surface entry and scope only the primitives that truly need per-instance reset, such as open flags, drafts, or selected local options.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
|
||||
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
|
||||
|
||||
## Feature-Scoped Jotai State
|
||||
|
||||
- A module's feature-local state lives in one state file for Jotai-backed features: primitive atoms, query atoms, derived atoms, write-only action atoms, mutation atoms, submission orchestration, provider exports, and optional scope configuration.
|
||||
- Keep state local when one component owns it, even inside Jotai-backed features. Dialog open flags, menu/popover visibility, confirmation visibility, form/input drafts, row-local pending flags, and in-flight refs usually belong in component state.
|
||||
- Promote UI state to an atom only when siblings need the same source of truth, the value drives a query or mutation atom, a parent workflow coordinates the state, or the state intentionally persists across hidden or unmounted descendants within a scoped surface.
|
||||
- Reflect atom-backed surface-wide locks or invariants in every affected trigger. If only one row, menu, or dialog should be disabled, keep the pending or lock state local to that row, menu, or dialog.
|
||||
- Atom order in the state file follows the dependency graph: types/constants, editable primitives, query atoms, query-data derived atoms, readiness/business derived atoms, write actions, mutation atoms, submission orchestration, provider exports.
|
||||
- Derived atom names read as business facts. Write atom names read as user or workflow commands.
|
||||
- UI components read and write the exact atom they use with `useAtomValue` or `useSetAtom`. Repeated workflow semantics live in named derived atoms or write atoms.
|
||||
- Non-query derived atoms return a narrow value with a clear domain name; avoid pass-through aliases or bundling unrelated UI facts. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
|
||||
- Write-only atoms own synchronous state transitions that update multiple primitives, reset dependent state, or advance the workflow. Async work with loading, error, caching, retry, or stale-result concerns should be modeled as query or mutation atoms, with write atoms only changing the inputs that drive them.
|
||||
- Avoid feature hooks that aggregate form values, query results, derived state, and commands for sibling components. Prefer named derived atoms and write atoms so UI components read the exact shared fact or command they need.
|
||||
- When a form library owns validation, keep submit orchestration in feature state when post-submit result or error state is shared by the surface. Avoid duplicating validation gates or request shaping in UI hooks.
|
||||
- Non-query derived atoms return a narrow value with a clear domain name. Query atoms expose the TanStack Query result object so loading, error, fetch, and pagination state stay attached to the query contract.
|
||||
- Write-only atoms own state transitions that update multiple primitives, reset dependent state, guard stale async work, or advance the workflow.
|
||||
- `jotai-tanstack-query` atoms use the same QueryClient as the React Query provider. Query atoms belong in feature state when atoms are the feature's local state surface.
|
||||
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query and mutation atoms keep shared cache behavior through the shared QueryClient.
|
||||
- Do not put `atomWithQuery`, `atomWithInfiniteQuery`, `atomWithMutation`, or broad derived orchestration atoms in a `ScopeProvider` just to reset a surface. Scoped derived atoms implicitly scope their dependencies, which can duplicate query client access and break shared invalidation. Leave query/mutation atoms unscoped; let them read scoped primitive inputs.
|
||||
- Scope providers should list resettable primitive atoms and explicit hydration tuples. If a derived atom must be scoped, confirm that every dependency it implicitly scopes is meant to be private to that surface.
|
||||
- Keep independent dialog lifecycles separate. Avoid a single discriminated "current action dialog" atom when edit, delete, and other dialogs have their own open state, loading guard, or reset behavior.
|
||||
- Route-derived stable identities that do not need instance reset or scoped isolation can be hydrated at the route or layout boundary into a feature route atom. Use scoped atoms only when stale cross-instance state or per-surface reset semantics are needed.
|
||||
- Jotai scope is an optional instance-isolation tool for secondary surfaces with independent local state. Query atoms keep shared cache behavior through the shared QueryClient.
|
||||
|
||||
## Components, Props, And Types
|
||||
|
||||
@ -82,7 +71,6 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Use generated enum objects and union types directly in props, comparisons, status logic, and i18n keys. Do not add local enum constants or parallel frontend enum/status layers unless they model real product state not represented by the API. Presentation-only tone maps should be keyed by the generated enum.
|
||||
- Normalize or coerce only at a real boundary, such as user-entered forms, search, URL/query params, file names, DOM IDs, or legacy adapters. Preserve user-entered values when whitespace or formatting can be meaningful.
|
||||
- Do not coerce nullable or optional API strings to `''` in query, derived model, or payload-building code. Keep `undefined` or `null` until the final boundary that requires a string.
|
||||
- Do not use `value || undefined` for mutation payload fields where an empty string means "clear this value". Trim or normalize at the form boundary, then preserve `''` when the API contract treats it as an intentional update.
|
||||
- Local UI models are fine for presentation, form state, select options, or guarded required-field refinements. Name them as UI concepts, not generated DTO mirrors.
|
||||
- Required-value refinements are allowed only after same-branch filtering or early return. Prefer nullable-tolerant props for render-only data.
|
||||
- When a component needs a stricter shape than a generated DTO, refine once at the API/query-to-UI boundary into a purpose-named UI type instead of hiding missing fields with generic fallback or coercion helpers.
|
||||
@ -102,17 +90,12 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
|
||||
- Keep `web/contract/*` as the single source of truth for API shape; follow existing domain/router patterns and the `{ params, query?, body? }` input shape.
|
||||
- Consume queries directly with `useQuery(consoleQuery.xxx.queryOptions(...))` or `useQuery(marketplaceQuery.xxx.queryOptions(...))`.
|
||||
- In `atomWithQuery` and `atomWithInfiniteQuery`, return generated `queryOptions()` or `infiniteOptions()` directly. Pass `enabled`, `retry`, `placeholderData`, `select`, and pagination options into that call instead of spreading generated options into a hand-built object.
|
||||
- In `atomWithMutation`, return generated `mutationOptions()` directly when using generated clients. Put request shaping and submit orchestration in write atoms; do not rebuild mutation option objects just to pass through the generated mutation function.
|
||||
- For custom query functions that do not come from generated clients, wrap the options object with TanStack `queryOptions(...)` so query atoms still return a query options contract.
|
||||
- Avoid pass-through hooks and thin `web/service/use-*` wrappers that only rename `queryOptions()` or `mutationOptions()`. Extract a small `queryOptions` helper only when repeated call-site options justify it.
|
||||
- Keep feature hooks for real orchestration, workflow state, or shared domain behavior.
|
||||
- For TanStack cache data, use generated or query-derived types; do not create local wrappers for `getQueryData` or `getQueriesData`.
|
||||
- For generated oRPC `queryOptions()` / `infiniteOptions()`, keep returning the generated options directly. When required input is missing, use a whole-input branch such as `input: condition ? validInput : skipToken` together with `enabled: Boolean(condition)` so no request runs and no fake payload is built.
|
||||
- Do not put `skipToken` inside a nested placeholder payload, such as `{ params: { appInstanceId: skipToken } }`. Do not create hand-written "missing queryOptions" objects or coerce required IDs to `''`.
|
||||
- For generated oRPC `queryOptions()` / `infiniteOptions()`, do not pass `skipToken` as `input`; keep a valid placeholder input shape and use `enabled` to gate missing required params because the OpenAPI codec encodes input eagerly.
|
||||
- Consume mutations directly with `useMutation(consoleQuery.xxx.mutationOptions(...))` or `useMutation(marketplaceQuery.xxx.mutationOptions(...))`; use oRPC clients as `mutationFn` only for custom flows.
|
||||
- Put shared cache behavior in `createTanstackQueryUtils(...experimental_defaults...)`; components may add UI feedback callbacks, but should not own shared invalidation rules.
|
||||
- Component or atom mutation callbacks can handle local UI feedback such as toasts, closing dialogs, or navigation. They should not replace shared invalidation or add local cache patches for shared server state.
|
||||
- Do not use deprecated `useInvalid` or `useReset`.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required, and wrap awaited calls in `try/catch`.
|
||||
|
||||
@ -124,9 +107,8 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
|
||||
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
|
||||
- When a dialog, dropdown, or popover component already accepts controlled `open` state, mount the surface unconditionally unless unmounting is required for performance or reset semantics. Use keyed scope or local state reset for reset behavior instead of `{open && <Surface />}` wrappers.
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, children-as-pass-through composition, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook, forwards props, or passes trigger/content through to one child, move the logic into that child or make the wrapper own a real surface.
|
||||
- Avoid shallow wrappers, hook-to-props adapter components, layout-only render-prop wrappers, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary. If a component only calls a hook and forwards every returned field to one child, move the hook into that child or make the wrapper own a real surface.
|
||||
|
||||
## You Might Not Need An Effect
|
||||
|
||||
@ -135,7 +117,6 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Do not use Effects to handle user actions. Put action-specific logic in the event handler where the cause is known.
|
||||
- Do not use Effects to copy one state value into another state value representing the same concept. Pick one source of truth and derive the rest during render.
|
||||
- Do not reset or adjust state from props with an Effect. Prefer a `key` reset, storing a stable ID and deriving the selected object, or guarded same-component render-time adjustment when truly necessary.
|
||||
- For forms initialized from query data, prefer keyed remounts or surface-entry hydration of form/field atoms over an Effect that copies query data into form state.
|
||||
- Prefer framework data APIs or TanStack Query for data fetching instead of writing request Effects in components.
|
||||
- If an Effect still seems necessary, first name the external system it synchronizes with. If there is no external system, remove the Effect and restructure the state or event flow.
|
||||
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@ -29,7 +29,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@ -91,7 +91,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@ -142,7 +142,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
run: echo "autofix.ci updates pull request branches, not merge group refs."
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
2
.github/workflows/build-push.yml
vendored
@ -8,6 +8,8 @@ on:
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "feat/hitl-backend"
|
||||
- "feat/rbac"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
12
.github/workflows/cli-e2e.yml
vendored
12
.github/workflows/cli-e2e.yml
vendored
@ -79,7 +79,7 @@ jobs:
|
||||
ws2_app_id: ${{ steps.out.outputs.DIFY_E2E_WS2_APP_ID }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
@ -123,7 +123,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
@ -170,7 +170,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
@ -233,7 +233,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
@ -295,7 +295,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
@ -351,7 +351,7 @@ jobs:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v4
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/cli-edge.yml
vendored
2
.github/workflows/cli-edge.yml
vendored
@ -23,7 +23,7 @@ jobs:
|
||||
working-directory: ./cli
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
4
.github/workflows/cli-release.yml
vendored
4
.github/workflows/cli-release.yml
vendored
@ -35,7 +35,7 @@ jobs:
|
||||
dify_tag: ${{ steps.resolve.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -98,7 +98,7 @@ jobs:
|
||||
DIFY_TAG: ${{ needs.validate.outputs.dify_tag }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 1
|
||||
|
||||
2
.github/workflows/cli-smoke.yml
vendored
2
.github/workflows/cli-smoke.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/cli-tests.yml
vendored
2
.github/workflows/cli-tests.yml
vendored
@ -30,7 +30,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
4
.github/workflows/db-migration-test.yml
vendored
4
.github/workflows/db-migration-test.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@ -63,7 +63,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
25
.github/workflows/deploy-hitl.yml
vendored
Normal file
25
.github/workflows/deploy-hitl.yml
vendored
Normal file
@ -0,0 +1,25 @@
|
||||
name: Deploy HITL
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "build/feat/hitl"
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
2
.github/workflows/hotfix-cherry-pick.yml
vendored
2
.github/workflows/hotfix-cherry-pick.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
name: Require cherry-pick provenance
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/main-ci.yml
vendored
2
.github/workflows/main-ci.yml
vendored
@ -48,7 +48,7 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: changes
|
||||
with:
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Checkout default branch (trusted code)
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@fac544c07dec837d0ccb6301d7b5580bf5edae39 # v8.2.0
|
||||
|
||||
2
.github/workflows/pyrefly-type-coverage.yml
vendored
2
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
10
.github/workflows/style.yml
vendored
10
.github/workflows/style.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -114,7 +114,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -171,7 +171,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@ -189,7 +189,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
|
||||
uses: super-linter/super-linter/slim@4ce20838b8ab83717e78138c5b3a1407148e0918 # v8.7.0
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
2
.github/workflows/tool-test-sdks.yaml
vendored
2
.github/workflows/tool-test-sdks.yaml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
4
.github/workflows/translate-i18n-claude.yml
vendored
4
.github/workflows/translate-i18n-claude.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@806af32823ef69c8ef357086c573a902af641307 # v1.0.151
|
||||
uses: anthropics/claude-code-action@2fee15510437d71399d9139ed60433470484a8fb # v1.0.153
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/web-e2e.yml
vendored
2
.github/workflows/web-e2e.yml
vendored
@ -20,7 +20,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
8
.github/workflows/web-tests.yml
vendored
8
.github/workflows/web-tests.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -102,7 +102,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@ -134,7 +134,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
|
||||
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
@ -768,6 +768,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -25,7 +25,6 @@ from .plugin import (
|
||||
from .rbac import migrate_member_roles_to_rbac
|
||||
from .retention import (
|
||||
archive_workflow_runs,
|
||||
archive_workflow_runs_plan,
|
||||
clean_expired_messages,
|
||||
clean_workflow_runs,
|
||||
cleanup_orphaned_draft_variables,
|
||||
@ -52,7 +51,6 @@ from .vector import (
|
||||
__all__ = [
|
||||
"add_qdrant_index",
|
||||
"archive_workflow_runs",
|
||||
"archive_workflow_runs_plan",
|
||||
"backfill_plugin_auto_upgrade",
|
||||
"clean_expired_messages",
|
||||
"clean_workflow_runs",
|
||||
|
||||
@ -12,160 +12,10 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
|
||||
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
|
||||
from services.retention.workflow_run.tenant_prefix import tenant_prefix_condition
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_HEX_PREFIXES = tuple("0123456789abcdef")
|
||||
|
||||
|
||||
class WorkflowRunArchivePlanRow(TypedDict):
|
||||
tenant_prefix: str
|
||||
total_tenants: int
|
||||
workflow_runs: int
|
||||
workflow_node_executions: int
|
||||
paid_tenants: int
|
||||
unpaid_tenants: int
|
||||
|
||||
|
||||
class WorkflowRunArchiveTenantPlan(TypedDict):
|
||||
archive_tenant_ids: list[str] | None
|
||||
paid_tenant_ids: list[str]
|
||||
unpaid_tenant_ids: list[str]
|
||||
|
||||
|
||||
def _parse_tenant_prefixes(prefixes: str | None) -> list[str]:
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
parsed = []
|
||||
for raw_prefix in prefixes.split(","):
|
||||
prefix = raw_prefix.strip().lower()
|
||||
if not prefix:
|
||||
continue
|
||||
if len(prefix) != 1 or prefix not in _HEX_PREFIXES:
|
||||
raise click.UsageError("--tenant-prefixes must be a comma-separated list of hex digits, e.g. 0,1,a,f.")
|
||||
parsed.append(prefix)
|
||||
return sorted(set(parsed))
|
||||
|
||||
|
||||
def _get_archive_candidate_tenant_ids_by_prefix(
|
||||
prefix: str,
|
||||
*,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
) -> list[str]:
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowRun
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
conditions = [
|
||||
WorkflowRun.created_at < end_before,
|
||||
WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
|
||||
WorkflowRun.type.in_(WorkflowRunArchiver.ARCHIVED_TYPE),
|
||||
tenant_prefix_condition(WorkflowRun.tenant_id, prefix),
|
||||
]
|
||||
if start_from is not None:
|
||||
conditions.append(WorkflowRun.created_at >= start_from)
|
||||
|
||||
tenant_ids = db.session.scalars(
|
||||
sa.select(WorkflowRun.tenant_id).where(*conditions).distinct().order_by(WorkflowRun.tenant_id)
|
||||
).all()
|
||||
return list(tenant_ids)
|
||||
|
||||
|
||||
def _filter_paid_workflow_archive_tenant_ids(tenant_ids: list[str]) -> tuple[list[str], list[str]]:
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.billing_service import BillingService
|
||||
|
||||
tenant_ids = sorted(set(tenant_ids))
|
||||
if not tenant_ids:
|
||||
return [], []
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return tenant_ids, []
|
||||
|
||||
plans = BillingService.get_plan_bulk_with_cache(tenant_ids)
|
||||
paid_tenant_ids = [
|
||||
tenant_id
|
||||
for tenant_id in tenant_ids
|
||||
if plans.get(tenant_id) and plans[tenant_id].get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM)
|
||||
]
|
||||
unpaid_tenant_ids = sorted(set(tenant_ids) - set(paid_tenant_ids))
|
||||
return paid_tenant_ids, unpaid_tenant_ids
|
||||
|
||||
|
||||
def _resolve_archive_tenant_ids_from_plan(
|
||||
*,
|
||||
tenant_ids: str | None,
|
||||
tenant_prefixes: list[str],
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
) -> WorkflowRunArchiveTenantPlan:
|
||||
"""
|
||||
Resolve the archive tenant scope once before scanning workflow_runs.
|
||||
|
||||
Prefix rollout should use the tenant list collected by the same planning path, then archive by
|
||||
tenant_id IN (...). Scanning workflow_runs with a tenant prefix range in every archive run is too expensive on
|
||||
the large production table this command is meant to shrink.
|
||||
"""
|
||||
if tenant_ids:
|
||||
requested_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()]
|
||||
elif tenant_prefixes:
|
||||
requested_tenant_ids = []
|
||||
for prefix in tenant_prefixes:
|
||||
requested_tenant_ids.extend(
|
||||
_get_archive_candidate_tenant_ids_by_prefix(
|
||||
prefix,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return WorkflowRunArchiveTenantPlan(
|
||||
archive_tenant_ids=None,
|
||||
paid_tenant_ids=[],
|
||||
unpaid_tenant_ids=[],
|
||||
)
|
||||
|
||||
paid_tenant_ids, unpaid_tenant_ids = _filter_paid_workflow_archive_tenant_ids(requested_tenant_ids)
|
||||
return WorkflowRunArchiveTenantPlan(
|
||||
archive_tenant_ids=paid_tenant_ids,
|
||||
paid_tenant_ids=paid_tenant_ids,
|
||||
unpaid_tenant_ids=unpaid_tenant_ids,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_archive_time_range(
|
||||
*,
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
) -> tuple[int, datetime.datetime | None, datetime.datetime | None]:
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
raise click.UsageError("Choose either day offsets or explicit dates, not both.")
|
||||
if from_days_ago <= to_days_ago:
|
||||
raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if start_from and end_before and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be earlier than --end-before.")
|
||||
|
||||
return before_days, start_from, end_before
|
||||
|
||||
|
||||
@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.")
|
||||
@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30)
|
||||
@ -289,143 +139,11 @@ def clean_workflow_runs(
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"archive-workflow-runs-plan",
|
||||
help="Plan workflow run archive rollout by tenant ID first hex digit.",
|
||||
)
|
||||
@click.option("--before-days", default=90, show_default=True, help="Plan runs older than N days.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--to-days-ago",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
|
||||
)
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Plan runs created at or after this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Plan runs created before this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option(
|
||||
"--include-archived",
|
||||
is_flag=True,
|
||||
help="Compatibility no-op for V2 bundle archive; plan counts source rows in the requested window.",
|
||||
)
|
||||
def archive_workflow_runs_plan(
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
include_archived: bool,
|
||||
):
|
||||
"""
|
||||
Print the 16 tenant-prefix rollout rows used to choose archive execution order.
|
||||
|
||||
Counts use the same workflow run eligibility as archive-workflow-runs: ended runs,
|
||||
supported workflow types, and the requested created_at window. V2 bundle archive
|
||||
does not maintain per-run archive logs, so this plan reports source-table volume.
|
||||
"""
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
before_days, start_from, end_before = _resolve_archive_time_range(
|
||||
before_days=before_days,
|
||||
from_days_ago=from_days_ago,
|
||||
to_days_ago=to_days_ago,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
)
|
||||
plan_end_before = end_before or datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=before_days)
|
||||
if include_archived:
|
||||
click.echo(click.style("--include-archived is a no-op for V2 bundle archive plans.", fg="yellow"))
|
||||
|
||||
rows: list[WorkflowRunArchivePlanRow] = []
|
||||
for prefix in _HEX_PREFIXES:
|
||||
tenant_ids = _get_archive_candidate_tenant_ids_by_prefix(
|
||||
prefix,
|
||||
start_from=start_from,
|
||||
end_before=plan_end_before,
|
||||
)
|
||||
total_tenants = len(tenant_ids)
|
||||
paid_tenant_ids, unpaid_tenant_ids = _filter_paid_workflow_archive_tenant_ids(tenant_ids)
|
||||
|
||||
run_conditions = [
|
||||
WorkflowRun.created_at < plan_end_before,
|
||||
WorkflowRun.status.in_(WorkflowExecutionStatus.ended_values()),
|
||||
WorkflowRun.type.in_(WorkflowRunArchiver.ARCHIVED_TYPE),
|
||||
tenant_prefix_condition(WorkflowRun.tenant_id, prefix),
|
||||
]
|
||||
if start_from is not None:
|
||||
run_conditions.append(WorkflowRun.created_at >= start_from)
|
||||
workflow_runs = (
|
||||
db.session.scalar(sa.select(sa.func.count()).select_from(WorkflowRun).where(*run_conditions)) or 0
|
||||
)
|
||||
candidate_runs = sa.select(WorkflowRun.id).where(*run_conditions).subquery()
|
||||
workflow_node_executions = (
|
||||
db.session.scalar(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(WorkflowNodeExecutionModel)
|
||||
.join(candidate_runs, WorkflowNodeExecutionModel.workflow_run_id == candidate_runs.c.id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
rows.append(
|
||||
WorkflowRunArchivePlanRow(
|
||||
tenant_prefix=prefix,
|
||||
total_tenants=total_tenants,
|
||||
workflow_runs=workflow_runs,
|
||||
workflow_node_executions=workflow_node_executions,
|
||||
paid_tenants=len(paid_tenant_ids),
|
||||
unpaid_tenants=len(unpaid_tenant_ids),
|
||||
)
|
||||
)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Workflow archive plan for runs before {plan_end_before.isoformat()}"
|
||||
f"{f' and at/after {start_from.isoformat()}' if start_from else ''}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
click.echo("tenant_prefix,total_tenants,workflow_runs,workflow_node_executions,paid_tenants,unpaid_tenants")
|
||||
for row in rows:
|
||||
click.echo(
|
||||
f"{row['tenant_prefix']},{row['total_tenants']},{row['workflow_runs']},"
|
||||
f"{row['workflow_node_executions']},{row['paid_tenants']},{row['unpaid_tenants']}"
|
||||
)
|
||||
|
||||
ordered_rows = sorted(
|
||||
rows,
|
||||
key=lambda row: (row["workflow_runs"] + row["workflow_node_executions"], row["tenant_prefix"]),
|
||||
)
|
||||
click.echo("suggested_execution_order=" + ",".join(row["tenant_prefix"] for row in ordered_rows))
|
||||
|
||||
|
||||
@click.command(
|
||||
"archive-workflow-runs",
|
||||
help="Archive workflow runs for paid plan tenants to S3-compatible storage.",
|
||||
)
|
||||
@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.")
|
||||
@click.option(
|
||||
"--tenant-prefixes",
|
||||
default=None,
|
||||
help="Optional comma-separated tenant ID first hex digits for rollout waves, e.g. 0,1,a,f.",
|
||||
)
|
||||
@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.")
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
@ -451,36 +169,13 @@ def archive_workflow_runs_plan(
|
||||
default=None,
|
||||
help="Archive runs created before this timestamp (UTC if no timezone).",
|
||||
)
|
||||
@click.option("--batch-size", default=100, show_default=True, help="Maximum workflow runs per archive bundle.")
|
||||
@click.option(
|
||||
"--workers",
|
||||
default=1,
|
||||
show_default=True,
|
||||
type=int,
|
||||
help="Reserved; bundle archive currently runs serially.",
|
||||
)
|
||||
@click.option(
|
||||
"--run-shard-index",
|
||||
default=None,
|
||||
type=click.IntRange(min=0),
|
||||
help="Zero-based workflow run shard index for parallel cron jobs. Must be paired with --run-shard-total.",
|
||||
)
|
||||
@click.option(
|
||||
"--run-shard-total",
|
||||
default=None,
|
||||
type=click.IntRange(min=1, max=16),
|
||||
help="Total workflow run shard count for parallel cron jobs. Must be paired with --run-shard-index.",
|
||||
)
|
||||
@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.")
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.")
|
||||
@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without archiving.")
|
||||
@click.option(
|
||||
"--delete-after-archive",
|
||||
is_flag=True,
|
||||
help="Not supported by bundle archive; use a separate bundle delete workflow after validation.",
|
||||
)
|
||||
@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.")
|
||||
def archive_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
tenant_prefixes: str | None,
|
||||
before_days: int,
|
||||
from_days_ago: int | None,
|
||||
to_days_ago: int | None,
|
||||
@ -488,8 +183,6 @@ def archive_workflow_runs(
|
||||
end_before: datetime.datetime | None,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
run_shard_index: int | None,
|
||||
run_shard_total: int | None,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
delete_after_archive: bool,
|
||||
@ -497,19 +190,14 @@ def archive_workflow_runs(
|
||||
"""
|
||||
Archive workflow runs for paid plan tenants older than the specified days.
|
||||
|
||||
This command writes V2 tenant/month/shard archive bundles. Each bundle contains Parquet snapshots from:
|
||||
- workflow_runs
|
||||
- workflow_app_logs
|
||||
This command archives the following tables to storage:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
|
||||
Source database rows are always preserved by archive. Deletion must be handled by
|
||||
a separate bundle-level delete workflow after manifest, checksum, row-count, and
|
||||
restore-sampling validation. In --dry-run mode, no storage or database writes
|
||||
happen; the command estimates per-table Parquet bytes and object size instead.
|
||||
The workflow_runs and workflow_app_logs tables are preserved for UI listing.
|
||||
"""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
@ -521,58 +209,32 @@ def archive_workflow_runs(
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
before_days, start_from, end_before = _resolve_archive_time_range(
|
||||
before_days=before_days,
|
||||
from_days_ago=from_days_ago,
|
||||
to_days_ago=to_days_ago,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
)
|
||||
parsed_tenant_prefixes = _parse_tenant_prefixes(tenant_prefixes)
|
||||
except click.UsageError as e:
|
||||
click.echo(click.style(e.message, fg="red"))
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
click.echo(click.style("start-from and end-before must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if (from_days_ago is None) ^ (to_days_ago is None):
|
||||
click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red"))
|
||||
return
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
if start_from or end_before:
|
||||
click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red"))
|
||||
return
|
||||
if from_days_ago <= to_days_ago:
|
||||
click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red"))
|
||||
return
|
||||
now = datetime.datetime.now()
|
||||
start_from = now - datetime.timedelta(days=from_days_ago)
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if start_from and end_before and start_from >= end_before:
|
||||
click.echo(click.style("start-from must be earlier than end-before.", fg="red"))
|
||||
return
|
||||
plan_end_before = end_before or datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=before_days)
|
||||
if workers < 1:
|
||||
click.echo(click.style("workers must be at least 1.", fg="red"))
|
||||
return
|
||||
if (run_shard_index is None) ^ (run_shard_total is None):
|
||||
click.echo(click.style("run-shard-index and run-shard-total must be provided together.", fg="red"))
|
||||
return
|
||||
if run_shard_index is not None and run_shard_total is not None and run_shard_index >= run_shard_total:
|
||||
click.echo(click.style("run-shard-index must be less than run-shard-total.", fg="red"))
|
||||
return
|
||||
if delete_after_archive:
|
||||
click.echo(click.style("delete-after-archive is not supported by bundle archive.", fg="red"))
|
||||
return
|
||||
|
||||
try:
|
||||
tenant_plan = _resolve_archive_tenant_ids_from_plan(
|
||||
tenant_ids=tenant_ids,
|
||||
tenant_prefixes=parsed_tenant_prefixes,
|
||||
start_from=start_from,
|
||||
end_before=plan_end_before,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve workflow archive tenant plan")
|
||||
click.echo(click.style("Failed to resolve workflow archive tenant plan.", fg="red"))
|
||||
return
|
||||
|
||||
planned_tenant_ids = tenant_plan["archive_tenant_ids"]
|
||||
planned_paid_tenant_ids = tenant_plan["paid_tenant_ids"] if planned_tenant_ids is not None else None
|
||||
paid_tenants = len(tenant_plan["paid_tenant_ids"])
|
||||
unpaid_tenants = len(tenant_plan["unpaid_tenant_ids"])
|
||||
if planned_tenant_ids is not None:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Resolved archive tenant plan: paid_tenants={paid_tenants}, unpaid_tenants={unpaid_tenants}.",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
if not planned_tenant_ids:
|
||||
click.echo(click.style("No paid tenants matched the archive plan; nothing to archive.", fg="yellow"))
|
||||
return
|
||||
|
||||
archiver = WorkflowRunArchiver(
|
||||
days=before_days,
|
||||
@ -580,11 +242,7 @@ def archive_workflow_runs(
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
workers=workers,
|
||||
tenant_ids=planned_tenant_ids,
|
||||
tenant_prefixes=parsed_tenant_prefixes,
|
||||
paid_tenant_ids=planned_paid_tenant_ids,
|
||||
run_shard_index=run_shard_index,
|
||||
run_shard_total=run_shard_total,
|
||||
tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None,
|
||||
limit=limit,
|
||||
dry_run=dry_run,
|
||||
delete_after_archive=delete_after_archive,
|
||||
@ -594,9 +252,7 @@ def archive_workflow_runs(
|
||||
click.style(
|
||||
f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
|
||||
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
|
||||
f"bundles_archived={summary.bundles_archived}, bundles_skipped={summary.bundles_skipped}, "
|
||||
f"bundles_failed={summary.bundles_failed}, "
|
||||
f"object_size_bytes={summary.total_object_size_bytes}, time={summary.total_elapsed_time:.2f}s",
|
||||
f"time={summary.total_elapsed_time:.2f}s",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
@ -612,52 +268,6 @@ def archive_workflow_runs(
|
||||
)
|
||||
|
||||
|
||||
def _echo_bundle_archive_operation_summary(summary) -> None:
|
||||
status = "completed successfully" if summary.bundles_failed == 0 else "completed with failures"
|
||||
fg = "green" if summary.bundles_failed == 0 else "red"
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{summary.operation} {status}. "
|
||||
f"bundles_success={summary.bundles_succeeded} bundles_failed={summary.bundles_failed} "
|
||||
f"runs={summary.runs_processed} rows={summary.rows_processed} "
|
||||
f"archive_bytes={summary.archive_bytes} duration={summary.elapsed_time:.2f}s "
|
||||
f"validation_time={summary.validation_time:.2f}s "
|
||||
f"runs_per_second={summary.runs_per_second:.2f} rows_per_second={summary.rows_per_second:.2f} "
|
||||
f"bytes_per_second={summary.bytes_per_second:.2f}",
|
||||
fg=fg,
|
||||
)
|
||||
)
|
||||
click.echo(click.style("table,row_count", fg="white"))
|
||||
for table_name in [
|
||||
"workflow_runs",
|
||||
"workflow_app_logs",
|
||||
"workflow_node_executions",
|
||||
"workflow_node_execution_offload",
|
||||
"workflow_pauses",
|
||||
"workflow_pause_reasons",
|
||||
"workflow_trigger_logs",
|
||||
]:
|
||||
click.echo(f"{table_name},{summary.table_counts.get(table_name, 0)}")
|
||||
for result in summary.results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f" bundle={result.bundle_id} tenant={result.tenant_id} runs={result.run_count} "
|
||||
f"rows={result.row_count} archive_bytes={result.archive_bytes} "
|
||||
f"time={result.elapsed_time:.2f}s validation={result.validation_time:.2f}s",
|
||||
fg="white",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f" failed bundle={result.bundle_id} tenant={result.tenant_id} "
|
||||
f"object_prefix={result.object_prefix} error={result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
"restore-workflow-runs",
|
||||
help="Restore archived workflow runs from S3-compatible storage.",
|
||||
@ -680,8 +290,8 @@ def _echo_bundle_archive_operation_summary(summary) -> None:
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="V1 --run-id compatibility only.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of V2 bundles to restore.")
|
||||
@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without restoring.")
|
||||
def restore_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
@ -693,18 +303,15 @@ def restore_workflow_runs(
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
Restore archived workflow runs from storage to the database.
|
||||
Restore an archived workflow run from storage to the database.
|
||||
|
||||
Batch restore uses V2 bundle metadata and validates archive objects before writing source rows. This restores:
|
||||
- workflow_runs
|
||||
- workflow_app_logs
|
||||
This restores the following tables:
|
||||
- workflow_node_executions
|
||||
- workflow_node_execution_offload
|
||||
- workflow_pauses
|
||||
- workflow_pause_reasons
|
||||
- workflow_trigger_logs
|
||||
"""
|
||||
from services.retention.workflow_run.bundle_archive_maintenance import WorkflowRunBundleArchiveMaintenance
|
||||
from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore
|
||||
|
||||
parsed_tenant_ids = None
|
||||
@ -728,46 +335,39 @@ def restore_workflow_runs(
|
||||
)
|
||||
)
|
||||
|
||||
restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
|
||||
if run_id:
|
||||
restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers)
|
||||
results = [restorer.restore_by_run_id(run_id)]
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if workers != 1:
|
||||
click.echo(
|
||||
click.style("--workers is ignored for V2 bundle restore; bundles are processed serially.", fg="yellow")
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = restorer.restore_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
bundle_restorer = WorkflowRunBundleArchiveMaintenance(dry_run=dry_run, strict_content_validation=True)
|
||||
summary = bundle_restorer.restore_batch(
|
||||
tenant_ids=parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
_echo_bundle_archive_operation_summary(summary)
|
||||
return
|
||||
|
||||
|
||||
@click.command(
|
||||
@ -792,20 +392,8 @@ def restore_workflow_runs(
|
||||
default=None,
|
||||
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
|
||||
)
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of V2 bundles to delete.")
|
||||
@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.")
|
||||
@click.option("--dry-run", is_flag=True, help="Preview without deleting.")
|
||||
@click.option(
|
||||
"--skip-bad-archives",
|
||||
is_flag=True,
|
||||
help="Continue batch deletion when one archive object fails validation.",
|
||||
)
|
||||
@click.option(
|
||||
"--restore-sample-interval",
|
||||
type=int,
|
||||
default=0,
|
||||
show_default=True,
|
||||
help="Run restore dry-run after every N successful deletes; 0 disables restore sampling.",
|
||||
)
|
||||
def delete_archived_workflow_runs(
|
||||
tenant_ids: str | None,
|
||||
run_id: str | None,
|
||||
@ -813,16 +401,10 @@ def delete_archived_workflow_runs(
|
||||
end_before: datetime.datetime | None,
|
||||
limit: int,
|
||||
dry_run: bool,
|
||||
skip_bad_archives: bool,
|
||||
restore_sample_interval: int,
|
||||
):
|
||||
"""
|
||||
Delete archived workflow runs from the database.
|
||||
|
||||
Batch delete uses V2 bundle metadata and validates object existence, manifest schema, object size, checksum, row
|
||||
counts, and source/archive content checksums before deleting source rows. `--run-id` keeps the V1 per-run path.
|
||||
"""
|
||||
from services.retention.workflow_run.bundle_archive_maintenance import WorkflowRunBundleArchiveMaintenance
|
||||
from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion
|
||||
|
||||
parsed_tenant_ids = None
|
||||
@ -835,8 +417,6 @@ def delete_archived_workflow_runs(
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
if run_id is None and (start_from is None or end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before are required for batch delete.")
|
||||
if restore_sample_interval < 0:
|
||||
raise click.BadParameter("restore-sample-interval must be >= 0")
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
target_desc = f"workflow run {run_id}" if run_id else "workflow runs"
|
||||
@ -847,85 +427,56 @@ def delete_archived_workflow_runs(
|
||||
)
|
||||
)
|
||||
|
||||
deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run)
|
||||
if run_id:
|
||||
deleter = ArchivedWorkflowRunDeletion(
|
||||
dry_run=dry_run,
|
||||
skip_bad_archives=skip_bad_archives,
|
||||
restore_sample_interval=restore_sample_interval,
|
||||
)
|
||||
results = [deleter.delete_by_run_id(run_id)]
|
||||
for result in results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
|
||||
f"workflow run {result.run_id} (tenant={result.tenant_id}, "
|
||||
f"archive_key={result.archive_key}, counts={result.validated_counts})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
if result.restore_sampled:
|
||||
sample_status = "passed" if result.restore_sample_success else "failed"
|
||||
click.echo(
|
||||
click.style(
|
||||
f" restore dry-run sample {sample_status} for workflow run {result.run_id}",
|
||||
fg="green" if result.restore_sample_success else "red",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Failed to delete workflow run {result.run_id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
" runbook: pause this delete window, verify archive storage object and manifest/checksum, "
|
||||
"retry the same run after fixing storage or DB drift, or rerun with --skip-bad-archives "
|
||||
"to quarantine this run and continue the batch.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
results = deleter.delete_batch(
|
||||
parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
for result in results:
|
||||
if result.success:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed successfully. success={successes} duration={elapsed}",
|
||||
f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} "
|
||||
f"workflow run {result.run_id} (tenant={result.tenant_id})",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
f"Failed to delete workflow run {result.run_id}: {result.error}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if restore_sample_interval:
|
||||
click.echo(click.style("--restore-sample-interval is ignored for V2 bundle delete.", fg="yellow"))
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
bundle_deleter = WorkflowRunBundleArchiveMaintenance(
|
||||
dry_run=dry_run,
|
||||
strict_content_validation=True,
|
||||
stop_on_error=not skip_bad_archives,
|
||||
)
|
||||
summary = bundle_deleter.delete_batch(
|
||||
tenant_ids=parsed_tenant_ids,
|
||||
start_date=start_from,
|
||||
end_date=end_before,
|
||||
limit=limit,
|
||||
)
|
||||
_echo_bundle_archive_operation_summary(summary)
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
|
||||
successes = sum(1 for result in results if result.success)
|
||||
failures = len(results) - successes
|
||||
|
||||
if failures == 0:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed successfully. success={successes} duration={elapsed}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}",
|
||||
fg="red",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@ -1,107 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from services.enterprise import rbac_service as enterprise_rbac_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.app_service import AppListBaseParams
|
||||
from services.enterprise.rbac_service import MyPermissionsResponse
|
||||
|
||||
# Permission keys (dot-notation, from MyPermissionsResponse) that grant
|
||||
# list/preview access to an app. Keep this the single source of truth for both
|
||||
# the console and OpenAPI app-list endpoints.
|
||||
APP_LIST_PERMISSION_KEYS: frozenset[str] = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
|
||||
|
||||
# Workspace permission key that lets a caller see apps they maintain even when
|
||||
# those apps are not in their preview whitelist.
|
||||
_MANAGE_OWN_APPS_PERMISSION_KEY = "app.create_and_management"
|
||||
|
||||
|
||||
def has_app_list_permission(permission_keys: Sequence[str]) -> bool:
|
||||
"""Return True if any of ``permission_keys`` grants app list/preview access."""
|
||||
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppAccessFilter:
|
||||
"""Resolved RBAC visibility for app list/read endpoints.
|
||||
|
||||
``accessible_app_ids`` of ``None`` means the caller can see every app in the
|
||||
workspace (unrestricted). Otherwise it is the exact set of app ids the
|
||||
caller may preview; combined with ``can_manage_own_apps`` it also covers
|
||||
apps the caller maintains.
|
||||
"""
|
||||
|
||||
accessible_app_ids: set[str] | None
|
||||
can_manage_own_apps: bool
|
||||
|
||||
@classmethod
|
||||
def unrestricted(cls) -> AppAccessFilter:
|
||||
"""Filter that imposes no restriction (RBAC disabled / not applicable)."""
|
||||
return cls(accessible_app_ids=None, can_manage_own_apps=False)
|
||||
|
||||
def is_app_accessible(self, app_id: str, maintainer: str | None, account_id: str) -> bool:
|
||||
"""Whether a single app is visible to the caller under this filter.
|
||||
|
||||
Mirrors the service-layer query gate: an app is visible when the filter
|
||||
is unrestricted, the app id is whitelisted, or the caller maintains it
|
||||
and holds ``app.create_and_management``.
|
||||
"""
|
||||
if self.accessible_app_ids is None:
|
||||
return True
|
||||
if app_id in self.accessible_app_ids:
|
||||
return True
|
||||
return self.can_manage_own_apps and maintainer is not None and maintainer == account_id
|
||||
|
||||
def apply_to_params(self, params: AppListBaseParams) -> None:
|
||||
if self.accessible_app_ids is None:
|
||||
return
|
||||
params.accessible_app_ids = sorted(self.accessible_app_ids)
|
||||
params.include_own_apps = self.can_manage_own_apps
|
||||
|
||||
|
||||
def resolve_app_access_filter(
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
*,
|
||||
permissions: MyPermissionsResponse | None = None,
|
||||
) -> AppAccessFilter:
|
||||
"""Compute the RBAC app-access filter for ``account_id`` in ``tenant_id``.
|
||||
|
||||
Pass ``permissions`` when the caller has already fetched the snapshot (the
|
||||
console controller reuses it for per-app permission keys) to avoid a second
|
||||
inner-API round trip; otherwise it is fetched here.
|
||||
"""
|
||||
if permissions is None:
|
||||
permissions = enterprise_rbac_service.RBACService.MyPermissions.get(tenant_id, account_id)
|
||||
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(tenant_id, account_id)
|
||||
|
||||
can_manage_own_apps = _MANAGE_OWN_APPS_PERMISSION_KEY in permissions.workspace.permission_keys
|
||||
has_default_preview = has_app_list_permission(permissions.app.default_permission_keys) or has_app_list_permission(
|
||||
permissions.workspace.permission_keys
|
||||
)
|
||||
|
||||
permission_app_ids: set[str] | None = None
|
||||
if not has_default_preview:
|
||||
# Collect apps the caller can preview via per-app permission overrides.
|
||||
permission_app_ids = {
|
||||
override.resource_id
|
||||
for override in permissions.app.overrides
|
||||
if has_app_list_permission(override.permission_keys)
|
||||
}
|
||||
|
||||
accessible_app_ids: set[str] | None
|
||||
if getattr(whitelist_scope, "unrestricted", False):
|
||||
accessible_app_ids = permission_app_ids
|
||||
else:
|
||||
accessible_app_ids = set(whitelist_scope.resource_ids)
|
||||
if permission_app_ids is not None:
|
||||
accessible_app_ids |= permission_app_ids
|
||||
elif has_default_preview:
|
||||
# Default preview overrides the whitelist restriction.
|
||||
accessible_app_ids = None
|
||||
|
||||
return AppAccessFilter(accessible_app_ids=accessible_app_ids, can_manage_own_apps=can_manage_own_apps)
|
||||
@ -1,3 +1,23 @@
|
||||
"""Shared decorator utilities for Dify controller layers.
|
||||
|
||||
This module provides decorators that are not tied to any single API group (e.g.
|
||||
console, inner, service). Currently it exposes the RBAC permission gate, which
|
||||
can be applied to any blueprint.
|
||||
|
||||
Key exports
|
||||
-----------
|
||||
``rbac_permission_required`` – decorator that enforces enterprise RBAC access
|
||||
control. When ``RBAC_ENABLED`` is ``False`` it is a no-op.
|
||||
|
||||
``RBACPermission``, ``RBACResourceScope`` – re-exported from ``core.rbac`` so
|
||||
callers only need a single import site.
|
||||
|
||||
Private helpers
|
||||
---------------
|
||||
``_extract_resource_id``, ``_is_resource_owned_by_current_user`` – kept module-
|
||||
private but accessible via the module namespace for unit-test patching.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
@ -12,57 +32,7 @@ from models.dataset import Dataset
|
||||
from models.model import App
|
||||
from services.enterprise.rbac_service import RBACService
|
||||
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "enforce_rbac_access", "rbac_permission_required"]
|
||||
|
||||
|
||||
def enforce_rbac_access(
|
||||
*,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
resource_type: RBACResourceScope,
|
||||
scene: RBACPermission,
|
||||
resource_required: bool = True,
|
||||
path_args: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
"""Enforce enterprise RBAC for an explicit account/tenant pair.
|
||||
|
||||
This is the flask-login-independent core of the RBAC gate so it can run
|
||||
inside request-handling layers that resolve the caller themselves (e.g. the
|
||||
openapi auth pipeline, which has the account on ``AuthData`` before
|
||||
flask-login is mounted).
|
||||
|
||||
No-op when ``RBAC_ENABLED`` is ``False``. For resource-scoped checks the
|
||||
resource ID is taken from ``path_args`` merged with ``request.view_args``;
|
||||
resource ownership short-circuits the check. Raises ``Forbidden`` when
|
||||
access is denied. For workspace-level checks pass ``resource_required=False``
|
||||
so the RBAC request omits ``resource_id``.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant the access is evaluated against.
|
||||
account_id: The account requesting access.
|
||||
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
|
||||
scene: The :class:`RBACPermission` permission point, e.g. ``RBACPermission.APP_DELETE``.
|
||||
resource_required: Whether a concrete resource ID is required.
|
||||
path_args: Extra path arguments to merge with ``request.view_args``.
|
||||
"""
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
return
|
||||
|
||||
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
|
||||
resource_id = None
|
||||
if resource_required and check_resource_type:
|
||||
resource_id = _extract_resource_id(resource_type, path_args)
|
||||
if _is_resource_owned_by_current_user(tenant_id, account_id, resource_type, resource_id):
|
||||
return
|
||||
allowed = RBACService.CheckAccess.check(
|
||||
tenant_id,
|
||||
account_id,
|
||||
scene=scene,
|
||||
resource_type=check_resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
if not allowed:
|
||||
raise Forbidden()
|
||||
__all__ = ["RBACPermission", "RBACResourceScope", "rbac_permission_required"]
|
||||
|
||||
|
||||
def rbac_permission_required[**P, R](
|
||||
@ -71,12 +41,14 @@ def rbac_permission_required[**P, R](
|
||||
*,
|
||||
resource_required: bool = True,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Check enterprise RBAC permissions for the current flask-login user.
|
||||
"""Check enterprise RBAC permissions for the current user.
|
||||
|
||||
When ``RBAC_ENABLED`` is ``False`` the decorator is a no-op and the
|
||||
request passes through unchanged. When enabled it resolves the current
|
||||
account/tenant and delegates to :func:`enforce_rbac_access`, raising
|
||||
``Forbidden`` if access is denied.
|
||||
request passes through unchanged. When enabled it extracts the resource ID
|
||||
from ``request.view_args`` for resource-scoped checks, calls the RBAC
|
||||
service ``check-access`` endpoint, and raises ``Forbidden`` if the access
|
||||
is denied. For workspace-level checks, set ``resource_required=False`` so
|
||||
the RBAC request omits ``resource_id``.
|
||||
|
||||
Args:
|
||||
resource_type: The :class:`RBACResourceScope` member (app/dataset/workspace).
|
||||
@ -91,14 +63,23 @@ def rbac_permission_required[**P, R](
|
||||
return view(*args, **kwargs)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
enforce_rbac_access(
|
||||
tenant_id=current_tenant_id,
|
||||
account_id=current_user.id,
|
||||
resource_type=resource_type,
|
||||
check_resource_type = None if resource_type == RBACResourceScope.WORKSPACE else resource_type
|
||||
resource_id = None
|
||||
if resource_required and check_resource_type:
|
||||
resource_id = _extract_resource_id(resource_type, kwargs)
|
||||
if _is_resource_owned_by_current_user(current_tenant_id, current_user.id, resource_type, resource_id):
|
||||
return view(*args, **kwargs)
|
||||
allowed = RBACService.CheckAccess.check(
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
scene=scene,
|
||||
resource_required=resource_required,
|
||||
path_args=kwargs,
|
||||
resource_type=check_resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise Forbidden()
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -3,17 +3,16 @@ from uuid import UUID
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.agent.app_helpers import resolve_agent_app_model
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList, BaseApiKeyListResource, BaseApiKeyResource
|
||||
from controllers.console.app.app import (
|
||||
AppDetailWithSite as GenericAppDetailWithSite,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
AppListQuery,
|
||||
CopyAppPayload,
|
||||
_normalize_app_list_query_args,
|
||||
)
|
||||
from controllers.console.app.app import (
|
||||
@ -26,13 +25,9 @@ from controllers.console.app.app import (
|
||||
UpdateAppPayload as GenericUpdateAppPayload,
|
||||
)
|
||||
from controllers.console.wraps import (
|
||||
RBACPermission,
|
||||
RBACResourceScope,
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
rbac_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
@ -41,7 +36,6 @@ from extensions.ext_database import db
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
@ -54,8 +48,7 @@ from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App, IconType
|
||||
from models.model import IconType
|
||||
from services.agent.errors import AgentNotFoundError
|
||||
from services.agent.observability_service import (
|
||||
AgentLogQueryParams,
|
||||
@ -109,46 +102,6 @@ class AgentAppUpdatePayload(GenericUpdateAppPayload):
|
||||
return role
|
||||
|
||||
|
||||
class AgentAppCopyPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied agent")
|
||||
description: str | None = Field(default=None, description="Description for the copied agent", max_length=400)
|
||||
role: str | None = Field(default=None, description="Role for the copied agent", max_length=255)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
role = value.strip()
|
||||
if not role:
|
||||
raise ValueError("Agent role is required when provided.")
|
||||
return role
|
||||
|
||||
|
||||
class AgentApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable Agent service API")
|
||||
|
||||
|
||||
class AgentApiAccessResponse(BaseModel):
|
||||
enabled: bool
|
||||
service_api_base_url: str
|
||||
streaming_only: bool = True
|
||||
chat_endpoint: str
|
||||
stop_endpoint: str
|
||||
conversations_endpoint: str
|
||||
messages_endpoint: str
|
||||
files_upload_endpoint: str
|
||||
parameters_endpoint: str
|
||||
info_endpoint: str
|
||||
meta_endpoint: str
|
||||
api_rpm: int
|
||||
api_rph: int
|
||||
api_key_count: int
|
||||
|
||||
|
||||
class AgentAppPublishedReferenceResponse(BaseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
@ -232,7 +185,6 @@ class AgentStatisticsQuery(BaseModel):
|
||||
|
||||
class AgentAppPartial(GenericAppPartial):
|
||||
app_id: str | None = None
|
||||
debug_conversation_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
published_reference_count: int = 0
|
||||
@ -241,15 +193,10 @@ class AgentAppPartial(GenericAppPartial):
|
||||
|
||||
class AgentAppDetailWithSite(GenericAppDetailWithSite):
|
||||
app_id: str | None = None
|
||||
debug_conversation_id: str | None = None
|
||||
role: str | None = None
|
||||
active_config_is_published: bool = False
|
||||
|
||||
|
||||
class AgentDebugConversationRefreshResponse(BaseModel):
|
||||
debug_conversation_id: str
|
||||
|
||||
|
||||
class AgentAppPagination(GenericAppPagination):
|
||||
data: list[AgentAppPartial] = Field( # type: ignore[assignment] # pyrefly: ignore[bad-override-mutable-attribute]
|
||||
validation_alias=AliasChoices("items", "data")
|
||||
@ -260,8 +207,7 @@ register_schema_models(
|
||||
console_ns,
|
||||
AgentAppCreatePayload,
|
||||
AgentAppUpdatePayload,
|
||||
AgentAppCopyPayload,
|
||||
AgentApiStatusPayload,
|
||||
CopyAppPayload,
|
||||
AgentInviteOptionsQuery,
|
||||
AgentLogsQuery,
|
||||
AgentStatisticsQuery,
|
||||
@ -272,14 +218,11 @@ register_schema_models(
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppPagination,
|
||||
AgentApiAccessResponse,
|
||||
AgentAppPublishedReferenceResponse,
|
||||
AgentAppDetailWithSite,
|
||||
AgentAppPartial,
|
||||
AgentDebugConversationRefreshResponse,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentLogListResponse,
|
||||
AgentLogMessageListResponse,
|
||||
@ -294,7 +237,7 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
return AgentRosterService(db.session)
|
||||
|
||||
|
||||
def _serialize_agent_app_detail(app_model, *, current_user: Account) -> dict:
|
||||
def _serialize_agent_app_detail(app_model) -> dict:
|
||||
"""Serialize an Agent App detail using roster-only DTOs.
|
||||
|
||||
`/agent` responses are roster-shaped rather than raw app-shaped: `id`
|
||||
@ -317,11 +260,6 @@ def _serialize_agent_app_detail(app_model, *, current_user: Account) -> dict:
|
||||
payload.pop("bound_agent_id", None)
|
||||
payload["app_id"] = str(app_model.id)
|
||||
payload["id"] = agent.id
|
||||
payload["debug_conversation_id"] = roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
agent_id=agent.id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
payload["role"] = agent.role or ""
|
||||
payload["active_config_is_published"] = roster_service.active_config_is_published(
|
||||
tenant_id=app_model.tenant_id,
|
||||
@ -330,7 +268,7 @@ def _serialize_agent_app_detail(app_model, *, current_user: Account) -> dict:
|
||||
return payload
|
||||
|
||||
|
||||
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str, current_user: Account) -> dict:
|
||||
def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str) -> dict:
|
||||
"""Serialize Agent App lists with roster-shaped items.
|
||||
|
||||
Each item starts from the shared App list shape, then drops
|
||||
@ -353,11 +291,6 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str, current_u
|
||||
tenant_id=tenant_id,
|
||||
agent_ids=[agent.id for agent in agents_by_app_id.values()],
|
||||
)
|
||||
debug_conversation_ids_by_agent_id = roster_service.load_or_create_agent_app_debug_conversation_ids_by_agent_id(
|
||||
tenant_id=tenant_id,
|
||||
agents=list(agents_by_app_id.values()),
|
||||
account_id=current_user.id,
|
||||
)
|
||||
payload = AgentAppPagination.model_validate(app_pagination, from_attributes=True).model_dump(mode="json")
|
||||
for item in payload["data"]:
|
||||
app_id = item["id"]
|
||||
@ -366,7 +299,6 @@ def _serialize_agent_app_pagination(app_pagination, *, tenant_id: str, current_u
|
||||
if agent:
|
||||
item["app_id"] = app_id
|
||||
item["id"] = agent.id
|
||||
item["debug_conversation_id"] = debug_conversation_ids_by_agent_id.get(agent.id)
|
||||
item["role"] = agent.role or ""
|
||||
item["active_config_is_published"] = active_config_is_published_by_agent_id.get(agent.id, False)
|
||||
published_references = published_references_by_agent_id.get(agent.id, [])
|
||||
@ -391,38 +323,6 @@ def _resolve_agent_app_model(*, tenant_id: str, agent_id: UUID):
|
||||
return resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
|
||||
|
||||
def _agent_api_key_count(app_id: str) -> int:
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == ApiTokenType.APP,
|
||||
ApiToken.app_id == app_id,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
|
||||
def _serialize_agent_api_access(app_model: App) -> dict:
|
||||
base_url = app_model.api_base_url
|
||||
response = AgentApiAccessResponse(
|
||||
enabled=bool(app_model.enable_api),
|
||||
service_api_base_url=base_url,
|
||||
chat_endpoint=f"{base_url}/chat-messages",
|
||||
stop_endpoint=f"{base_url}/chat-messages/{{task_id}}/stop",
|
||||
conversations_endpoint=f"{base_url}/conversations",
|
||||
messages_endpoint=f"{base_url}/messages",
|
||||
files_upload_endpoint=f"{base_url}/files/upload",
|
||||
parameters_endpoint=f"{base_url}/parameters",
|
||||
info_endpoint=f"{base_url}/info",
|
||||
meta_endpoint=f"{base_url}/meta",
|
||||
api_rpm=app_model.api_rpm or 0,
|
||||
api_rph=app_model.api_rph or 0,
|
||||
api_key_count=_agent_api_key_count(str(app_model.id)),
|
||||
)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
|
||||
def _agent_observability_service() -> AgentObservabilityService:
|
||||
return AgentObservabilityService(db.session)
|
||||
|
||||
@ -474,11 +374,7 @@ class AgentAppListApi(Resource):
|
||||
empty = AgentAppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json")
|
||||
|
||||
return _serialize_agent_app_pagination(
|
||||
app_pagination,
|
||||
tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
return _serialize_agent_app_pagination(app_pagination, tenant_id=current_tenant_id)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent app created successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@ -503,7 +399,7 @@ class AgentAppListApi(Resource):
|
||||
)
|
||||
|
||||
app = AppService().create_app(current_tenant_id, params, current_user)
|
||||
return _serialize_agent_app_detail(app, current_user=current_user), 201
|
||||
return _serialize_agent_app_detail(app), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>")
|
||||
@ -513,11 +409,10 @@ class AgentAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _serialize_agent_app_detail(app_model, current_user=current_user)
|
||||
return _serialize_agent_app_detail(app_model)
|
||||
|
||||
@console_ns.expect(console_ns.models[AgentAppUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent app updated successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@ -527,9 +422,8 @@ class AgentAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
def put(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
args = AgentAppUpdatePayload.model_validate(console_ns.payload)
|
||||
args_dict: AppService.ArgsDict = {
|
||||
@ -543,7 +437,7 @@ class AgentAppApi(Resource):
|
||||
"role": args.role,
|
||||
}
|
||||
updated = AppService().update_app(app_model, args_dict)
|
||||
return _serialize_agent_app_detail(updated, current_user=current_user)
|
||||
return _serialize_agent_app_detail(updated)
|
||||
|
||||
@console_ns.response(204, "Agent app deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@ -558,34 +452,9 @@ class AgentAppApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/debug-conversation/refresh")
|
||||
class AgentDebugConversationRefreshApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Agent debug conversation refreshed",
|
||||
console_ns.models[AgentDebugConversationRefreshResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
debug_conversation_id = _agent_roster_service().refresh_agent_app_debug_conversation_id(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
account_id=current_user.id,
|
||||
)
|
||||
return AgentDebugConversationRefreshResponse(debug_conversation_id=debug_conversation_id).model_dump(
|
||||
mode="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/copy")
|
||||
class AgentAppCopyApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AgentAppCopyPayload.__name__])
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "Agent app copied successfully", console_ns.models[AgentAppDetailWithSite.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@ -596,88 +465,18 @@ class AgentAppCopyApi(Resource):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
args = AgentAppCopyPayload.model_validate(console_ns.payload or {})
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
copied_app = _agent_roster_service().duplicate_agent_app(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
account=current_user,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
role=args.role,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
return _serialize_agent_app_detail(copied_app, current_user=current_user), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-access")
|
||||
class AgentApiAccessApi(Resource):
|
||||
@console_ns.response(200, "Agent service API access", console_ns.models[AgentApiAccessResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return _serialize_agent_api_access(app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-enable")
|
||||
class AgentApiStatusApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AgentApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "Agent service API status updated", console_ns.models[AgentApiAccessResponse.__name__])
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, agent_id: UUID):
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
args = AgentApiStatusPayload.model_validate(console_ns.payload)
|
||||
app_model = AppService().update_app_api_status(app_model, args.enable_api)
|
||||
return _serialize_agent_api_access(app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-keys")
|
||||
class AgentApiKeyListApi(BaseApiKeyListResource):
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
token_prefix = "app-"
|
||||
|
||||
@console_ns.response(200, "Agent service API keys", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID) -> dict[str, object]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(app_model.id), tenant_id))
|
||||
|
||||
@console_ns.response(201, "Agent service API key created", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
def post(self, tenant_id: str, agent_id: UUID) -> tuple[dict[str, object], int]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(app_model.id), tenant_id)), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/api-keys/<uuid:api_key_id>")
|
||||
class AgentApiKeyApi(BaseApiKeyResource):
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
|
||||
@console_ns.response(204, "Agent service API key deleted")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_RELEASE_AND_VERSION)
|
||||
def delete(self, tenant_id: str, current_user: Account, agent_id: UUID, api_key_id: UUID) -> tuple[str, int]:
|
||||
app_model = _resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
self._delete_api_key(str(app_model.id), str(api_key_id), tenant_id, current_user)
|
||||
return "", 204
|
||||
return _serialize_agent_app_detail(copied_app), 201
|
||||
|
||||
|
||||
@console_ns.route("/agent/invite-options")
|
||||
@ -850,24 +649,3 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
version_id=str(version_id),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/versions/<uuid:version_id>/restore")
|
||||
class AgentRosterVersionRestoreApi(Resource):
|
||||
@console_ns.response(200, "Agent version restored", console_ns.models[AgentConfigSnapshotRestoreResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, current_user: Account, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotRestoreResponse,
|
||||
_agent_roster_service().restore_agent_version(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
account_id=current_user.id,
|
||||
),
|
||||
)
|
||||
|
||||
@ -10,12 +10,8 @@ backend — drive data lives in the API's own DB/storage, served straight from
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -53,10 +49,6 @@ class AgentDriveFileByAgentQuery(BaseModel):
|
||||
key: str = Field(min_length=1, description="Drive key, e.g. tender-analyzer/SKILL.md")
|
||||
|
||||
|
||||
class AgentDriveSkillInspectQuery(BaseModel):
|
||||
node_id: str | None = Field(default=None, description="Workflow node ID (workflow composer variant)")
|
||||
|
||||
|
||||
class AgentDriveItemResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
@ -64,63 +56,12 @@ class AgentDriveItemResponse(ResponseModel):
|
||||
hash: str | None = None
|
||||
file_kind: str
|
||||
created_at: int | None = None
|
||||
is_skill: bool | None = None
|
||||
skill_metadata: str | None = None
|
||||
|
||||
|
||||
class AgentDriveListResponse(ResponseModel):
|
||||
items: list[AgentDriveItemResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDriveSkillItemResponse(ResponseModel):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None = None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
hash: str | None = None
|
||||
created_at: int | None = None
|
||||
|
||||
|
||||
class AgentDriveSkillListResponse(ResponseModel):
|
||||
items: list[AgentDriveSkillItemResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDriveSkillFileResponse(ResponseModel):
|
||||
path: str
|
||||
name: str
|
||||
type: str
|
||||
drive_key: str | None = None
|
||||
available_in_drive: bool
|
||||
|
||||
|
||||
class AgentDriveSkillMarkdownResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
class AgentDriveSkillInspectResponse(ResponseModel):
|
||||
path: str
|
||||
skill_md_key: str
|
||||
archive_key: str | None = None
|
||||
name: str
|
||||
description: str
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
hash: str | None = None
|
||||
created_at: int | None = None
|
||||
source: str
|
||||
files: list[AgentDriveSkillFileResponse] = Field(default_factory=list)
|
||||
file_tree: list[dict[str, Any]] = Field(default_factory=list)
|
||||
skill_md: AgentDriveSkillMarkdownResponse
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentDrivePreviewResponse(ResponseModel):
|
||||
key: str
|
||||
size: int | None = None
|
||||
@ -134,12 +75,7 @@ class AgentDriveDownloadResponse(ResponseModel):
|
||||
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentDriveDownloadResponse,
|
||||
AgentDriveListResponse,
|
||||
AgentDrivePreviewResponse,
|
||||
AgentDriveSkillInspectResponse,
|
||||
AgentDriveSkillListResponse,
|
||||
console_ns, AgentDriveListResponse, AgentDrivePreviewResponse, AgentDriveDownloadResponse
|
||||
)
|
||||
|
||||
|
||||
@ -160,13 +96,6 @@ def _handle(exc: AgentDriveError) -> tuple[dict[str, object], int]:
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
|
||||
|
||||
def _json_response(data: Mapping[str, Any]):
|
||||
return Response(
|
||||
response=json.dumps(data, ensure_ascii=False, separators=(",", ":")),
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
|
||||
|
||||
_WORKFLOW_APP_MODES = [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]
|
||||
|
||||
|
||||
@ -190,49 +119,6 @@ class AgentDriveListByAgentApi(Resource):
|
||||
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/skills")
|
||||
class AgentDriveSkillListByAgentApi(Resource):
|
||||
@console_ns.doc("list_agent_drive_skills_by_agent")
|
||||
@console_ns.doc(description="List drive-backed skills for an Agent App")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID"})
|
||||
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
try:
|
||||
items = AgentDriveService().list_skills(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
return {"items": items}
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/skills/<path:skill_path>/inspect")
|
||||
class AgentDriveSkillInspectByAgentApi(Resource):
|
||||
@console_ns.doc("inspect_agent_drive_skill_by_agent")
|
||||
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
|
||||
@console_ns.doc(params={"agent_id": "Agent ID", "skill_path": "Skill path/slug, e.g. tender-analyzer"})
|
||||
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, skill_path: str):
|
||||
resolve_agent_app_model(tenant_id=tenant_id, agent_id=agent_id)
|
||||
try:
|
||||
return _json_response(
|
||||
AgentDriveService().inspect_skill(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
skill_path=skill_path,
|
||||
)
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/drive/files/preview")
|
||||
class AgentDrivePreviewByAgentApi(Resource):
|
||||
@console_ns.doc("preview_agent_drive_file_by_agent")
|
||||
@ -296,61 +182,6 @@ class AgentDriveListApi(Resource):
|
||||
return {"items": [{k: v for k, v in item.items() if k != "file_id"} for item in items]}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills")
|
||||
class AgentDriveSkillListApi(Resource):
|
||||
@console_ns.doc("list_agent_drive_skills")
|
||||
@console_ns.doc(description="List drive-backed skills for the bound agent")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentDriveListQuery)})
|
||||
@console_ns.response(200, "Drive skills", console_ns.models[AgentDriveSkillListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=_WORKFLOW_APP_MODES)
|
||||
def get(self, app_model: App):
|
||||
query = query_params_from_request(AgentDriveListQuery)
|
||||
agent_id = _resolve_agent_id(app_model, query.node_id)
|
||||
if not agent_id:
|
||||
return _agent_not_bound()
|
||||
try:
|
||||
items = AgentDriveService().list_skills(tenant_id=app_model.tenant_id, agent_id=agent_id)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
return {"items": items}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/skills/<path:skill_path>/inspect")
|
||||
class AgentDriveSkillInspectApi(Resource):
|
||||
@console_ns.doc("inspect_agent_drive_skill")
|
||||
@console_ns.doc(description="Inspect one drive-backed skill for slash-menu hover/detail UI")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"skill_path": "Skill path/slug, e.g. tender-analyzer",
|
||||
**query_params_from_model(AgentDriveSkillInspectQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Drive skill inspect view", console_ns.models[AgentDriveSkillInspectResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=_WORKFLOW_APP_MODES)
|
||||
def get(self, app_model: App, skill_path: str):
|
||||
query = query_params_from_request(AgentDriveSkillInspectQuery)
|
||||
agent_id = _resolve_agent_id(app_model, query.node_id)
|
||||
if not agent_id:
|
||||
return _agent_not_bound()
|
||||
try:
|
||||
return _json_response(
|
||||
AgentDriveService().inspect_skill(
|
||||
tenant_id=app_model.tenant_id,
|
||||
agent_id=agent_id,
|
||||
skill_path=skill_path,
|
||||
)
|
||||
)
|
||||
except AgentDriveError as exc:
|
||||
return _handle(exc)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/drive/files/preview")
|
||||
class AgentDrivePreviewApi(Resource):
|
||||
@console_ns.doc("preview_agent_drive_file")
|
||||
@ -401,8 +232,4 @@ __all__ = [
|
||||
"AgentDriveListByAgentApi",
|
||||
"AgentDrivePreviewApi",
|
||||
"AgentDrivePreviewByAgentApi",
|
||||
"AgentDriveSkillInspectApi",
|
||||
"AgentDriveSkillInspectByAgentApi",
|
||||
"AgentDriveSkillListApi",
|
||||
"AgentDriveSkillListByAgentApi",
|
||||
]
|
||||
|
||||
@ -14,7 +14,6 @@ from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.app_access import resolve_app_access_filter
|
||||
from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import (
|
||||
@ -79,6 +78,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
|
||||
DEFAULT_APP_LIST_MODE: AppListMode = "all"
|
||||
APP_LIST_PERMISSION_KEYS = frozenset({"app.preview", "app.acl.preview", "app.full_access"})
|
||||
|
||||
|
||||
class AppListBaseQuery(BaseModel):
|
||||
@ -167,6 +167,10 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
return normalized
|
||||
|
||||
|
||||
def _has_app_list_permission(permission_keys: Sequence[str]) -> bool:
|
||||
return any(permission_key in APP_LIST_PERMISSION_KEYS for permission_key in permission_keys)
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -608,12 +612,38 @@ class AppListApi(Resource):
|
||||
current_user_id,
|
||||
)
|
||||
if dify_config.RBAC_ENABLED:
|
||||
access_filter = resolve_app_access_filter(
|
||||
whitelist_scope = enterprise_rbac_service.RBACService.AppAccess.whitelist_resources(
|
||||
str(current_tenant_id),
|
||||
current_user_id,
|
||||
permissions=permissions,
|
||||
)
|
||||
access_filter.apply_to_params(params)
|
||||
can_manage_own_apps = "app.create_and_management" in permissions.workspace.permission_keys
|
||||
has_default_preview = _has_app_list_permission(
|
||||
permissions.app.default_permission_keys
|
||||
) or _has_app_list_permission(permissions.workspace.permission_keys)
|
||||
permission_app_ids: set[str] | None = None
|
||||
if not has_default_preview:
|
||||
permission_app_ids = {
|
||||
override.resource_id
|
||||
for override in permissions.app.overrides
|
||||
if _has_app_list_permission(override.permission_keys)
|
||||
}
|
||||
|
||||
if getattr(whitelist_scope, "unrestricted", False):
|
||||
accessible_app_ids = permission_app_ids
|
||||
else:
|
||||
accessible_app_ids = set(whitelist_scope.resource_ids)
|
||||
if permission_app_ids is not None:
|
||||
accessible_app_ids |= permission_app_ids
|
||||
elif has_default_preview:
|
||||
accessible_app_ids = None
|
||||
|
||||
if accessible_app_ids:
|
||||
params.accessible_app_ids = sorted(accessible_app_ids)
|
||||
params.include_own_apps = can_manage_own_apps
|
||||
elif accessible_app_ids is not None and can_manage_own_apps:
|
||||
params.is_created_by_me = True
|
||||
elif accessible_app_ids is not None:
|
||||
params.accessible_app_ids = []
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
|
||||
@ -138,9 +138,7 @@ class ChatMessageTextApi(Resource):
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
payload_data = dict(console_ns.payload or {})
|
||||
payload_data.setdefault("text", "")
|
||||
payload = TextToSpeechPayload.model_validate(payload_data)
|
||||
payload = TextToSpeechPayload.model_validate(console_ns.payload)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model,
|
||||
|
||||
@ -40,15 +40,12 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from services.agent.errors import AgentNotFoundError
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -194,11 +191,10 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_TEST_AND_RUN)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
def post(self, current_tenant_id: str, current_user: Account, app_model: App):
|
||||
return _create_chat_message(current_tenant_id=current_tenant_id, current_user=current_user, app_model=app_model)
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
return _create_chat_message(current_user=current_user, app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
|
||||
@ -219,12 +215,7 @@ class AgentChatMessageApi(Resource):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
|
||||
return _create_chat_message(
|
||||
current_tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
app_model=app_model,
|
||||
agent_id=str(agent_id),
|
||||
)
|
||||
return _create_chat_message(current_user=current_user, app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
@ -258,45 +249,11 @@ class AgentChatMessageStopApi(Resource):
|
||||
return _stop_chat_message(current_user_id=current_user_id, app_model=app_model, task_id=task_id)
|
||||
|
||||
|
||||
def _resolve_current_user_agent_debug_conversation_id(
|
||||
*, current_tenant_id: str, current_user: Account, app_model: App, agent_id: str | None
|
||||
) -> str:
|
||||
roster_service = AgentRosterService(db.session)
|
||||
if agent_id:
|
||||
return roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=current_tenant_id,
|
||||
agent_id=agent_id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
agent = roster_service.get_app_backing_agent(tenant_id=current_tenant_id, app_id=str(app_model.id))
|
||||
if agent is None:
|
||||
raise AgentNotFoundError()
|
||||
return roster_service.get_or_create_agent_app_debug_conversation_id(
|
||||
tenant_id=current_tenant_id,
|
||||
agent_id=agent.id,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
|
||||
|
||||
def _create_chat_message(
|
||||
*, current_user: Account, app_model: App, current_tenant_id: str | None = None, agent_id: str | None = None
|
||||
):
|
||||
def _create_chat_message(*, current_user: Account, app_model: App):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
debug_conversation_id = _resolve_current_user_agent_debug_conversation_id(
|
||||
current_tenant_id=current_tenant_id or app_model.tenant_id,
|
||||
current_user=current_user,
|
||||
app_model=app_model,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
if args_model.conversation_id and args_model.conversation_id != debug_conversation_id:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
args["conversation_id"] = debug_conversation_id
|
||||
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
|
||||
@ -53,7 +53,6 @@ from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
@ -187,11 +186,10 @@ class ChatMessageListApi(Resource):
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
@with_current_user
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
return _list_chat_messages(app_model=app_model, current_user=current_user)
|
||||
def get(self, app_model: App):
|
||||
return _list_chat_messages(app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/agent/<uuid:agent_id>/chat-messages")
|
||||
@ -207,11 +205,10 @@ class AgentChatMessageListApi(Resource):
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.APP, RBACPermission.APP_VIEW_LAYOUT)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, agent_id: UUID):
|
||||
def get(self, current_tenant_id: str, agent_id: UUID):
|
||||
app_model = resolve_agent_app_model(tenant_id=current_tenant_id, agent_id=agent_id)
|
||||
return _list_chat_messages(app_model=app_model, current_user=current_user)
|
||||
return _list_chat_messages(app_model=app_model)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
@ -393,24 +390,14 @@ class AgentMessageApi(Resource):
|
||||
return _get_message_detail(app_model=app_model, message_id=message_id)
|
||||
|
||||
|
||||
def _list_chat_messages(*, app_model: App, current_user: Account | None = None):
|
||||
def _list_chat_messages(*, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT and current_user is not None:
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=args.conversation_id,
|
||||
user=current_user,
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
else:
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.limit(1)
|
||||
)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -83,7 +83,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
|
||||
@ -26,7 +26,6 @@ from controllers.console.wraps import (
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -391,7 +390,6 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
|
||||
try:
|
||||
response = HitTestingService.external_retrieve(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query=payload.query,
|
||||
account=current_user,
|
||||
|
||||
@ -18,7 +18,6 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import resolve_account_fallback
|
||||
from models.account import Account
|
||||
@ -116,7 +115,6 @@ class DatasetsHitTestingBase:
|
||||
try:
|
||||
current_user, _ = resolve_account_fallback(current_user, current_tenant_id)
|
||||
response = HitTestingService.retrieve(
|
||||
session=db.session,
|
||||
dataset=dataset,
|
||||
query=cast(str, args.get("query")),
|
||||
account=current_user,
|
||||
|
||||
@ -222,7 +222,7 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -5,7 +5,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import TenantPluginPermission
|
||||
@ -18,9 +17,6 @@ def plugin_permission_required(
|
||||
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if dify_config.RBAC_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@ -169,7 +169,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -244,7 +244,7 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -326,7 +326,7 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
|
||||
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_CREATE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
@ -481,7 +481,7 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
|
||||
@ -469,7 +469,6 @@ class PluginDebuggingKeyApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_DEBUG, resource_required=False)
|
||||
@plugin_permission_required(debug_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -615,7 +614,6 @@ class PluginUploadFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -636,7 +634,6 @@ class PluginUploadFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -656,7 +653,6 @@ class PluginUploadFromBundleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -677,7 +673,6 @@ class PluginInstallFromPkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -698,7 +693,6 @@ class PluginInstallFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -725,7 +719,6 @@ class PluginInstallFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -746,7 +739,6 @@ class PluginFetchMarketplacePkgApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -772,7 +764,6 @@ class PluginFetchManifestApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -793,7 +784,6 @@ class PluginFetchInstallTasksApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
@ -811,7 +801,6 @@ class PluginFetchInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, task_id: str):
|
||||
@ -827,7 +816,6 @@ class PluginDeleteInstallTaskApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str):
|
||||
@ -843,7 +831,6 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -859,7 +846,6 @@ class PluginDeleteInstallTaskItemApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, task_id: str, identifier: str):
|
||||
@ -876,7 +862,6 @@ class PluginUpgradeFromMarketplaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -899,7 +884,6 @@ class PluginUpgradeFromGithubApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -927,7 +911,6 @@ class PluginUninstallApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_INSTALL, resource_required=False)
|
||||
@plugin_permission_required(install_required=True)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
@ -1058,11 +1041,10 @@ class PluginChangeAutoUpgradeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
if not dify_config.RBAC_ENABLED and not user.is_admin_or_owner:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = ParserAutoUpgradeChange.model_validate(console_ns.payload)
|
||||
@ -1115,7 +1097,6 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.PLUGIN_PREFERENCES, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
# exclude one single plugin
|
||||
|
||||
@ -211,7 +211,7 @@ def _legacy_workspace_roles(
|
||||
name=role_name,
|
||||
description="",
|
||||
is_builtin=True,
|
||||
permission_keys=list(dict.fromkeys(_LEGACY_ROLE_PERMISSION_KEYS[role_name])),
|
||||
permission_keys=list(_LEGACY_ROLE_PERMISSION_KEYS[role_name]),
|
||||
role_tag="owner" if role_name == "owner" else "",
|
||||
)
|
||||
for role_name in ("owner", "admin", "editor", "normal", "dataset_operator")
|
||||
@ -244,6 +244,11 @@ def _legacy_workspace_roles(
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Permission catalogs.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/role-permissions/catalog")
|
||||
class RBACWorkspaceCatalogApi(Resource):
|
||||
@login_required
|
||||
@ -370,6 +375,30 @@ class RBACRoleCopyApi(Resource):
|
||||
return _dump(role), 201
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/roles/<uuid:role_id>/members")
|
||||
class RBACRoleMembersApi(Resource):
|
||||
@login_required
|
||||
@rbac_permission_required(
|
||||
RBACResourceScope.WORKSPACE, RBACPermission.WORKSPACE_ROLE_MANAGE, resource_required=False
|
||||
)
|
||||
@console_ns.response(200, "Success", console_ns.models[_RBACRoleAccountList.__name__])
|
||||
def get(self, role_id):
|
||||
tenant_id, account_id = _current_ids()
|
||||
return _dump(
|
||||
svc.RBACService.Roles.members(
|
||||
tenant_id,
|
||||
account_id,
|
||||
str(role_id),
|
||||
options=_pagination_options(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access policies (tenant-level permission sets).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _AccessPolicyCreateRequest(BaseModel):
|
||||
name: str
|
||||
resource_type: svc.RBACResourceType
|
||||
@ -759,6 +788,11 @@ class RBACDatasetMemberBindingsApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Workspace-level access (Settings > Access Rules).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/rbac/workspace/apps/access-policy")
|
||||
class RBACWorkspaceAppMatrixApi(Resource):
|
||||
@login_required
|
||||
|
||||
@ -971,7 +971,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_USE, resource_required=False)
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.CREDENTIAL_MANAGE, resource_required=False)
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
@ -1070,7 +1070,6 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user: Account):
|
||||
@ -1126,7 +1125,6 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str):
|
||||
payload = MCPProviderUpdatePayload.model_validate(console_ns.payload or {})
|
||||
@ -1180,7 +1178,6 @@ class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
@ -1199,7 +1196,6 @@ class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
payload = MCPAuthPayload.model_validate(console_ns.payload or {})
|
||||
@ -1304,7 +1300,6 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@rbac_permission_required(RBACResourceScope.WORKSPACE, RBACPermission.MCP_MANAGE, resource_required=False)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider_id: str):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
|
||||
@ -31,7 +31,7 @@ from controllers.openapi._models import (
|
||||
AppDslExportQuery,
|
||||
AppDslExportResponse,
|
||||
AppDslImportPayload,
|
||||
AppInfo,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
@ -62,6 +62,7 @@ from controllers.openapi._models import (
|
||||
SessionListQuery,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
TaskStopResponse,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
@ -95,11 +96,12 @@ register_response_schema_models(
|
||||
openapi_ns,
|
||||
ErrorBody,
|
||||
EventStreamResponse,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfo,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppDslExportResponse,
|
||||
|
||||
@ -63,8 +63,6 @@ class OpenApiErrorCode(StrEnum):
|
||||
FILE_EXTENSION_BLOCKED = "file_extension_blocked"
|
||||
MEMBER_LIMIT_EXCEEDED = "member_limit_exceeded"
|
||||
MEMBER_LICENSE_EXCEEDED = "member_license_exceeded"
|
||||
HUMAN_INPUT_FORM_NOT_FOUND = "form_not_found"
|
||||
RECIPIENT_SURFACE_MISMATCH = "recipient_surface_mismatch"
|
||||
|
||||
|
||||
class ErrorDetail(BaseModel):
|
||||
@ -241,16 +239,3 @@ class MemberLicenseExceeded(OpenApiError): # noqa: N818
|
||||
error_code = OpenApiErrorCode.MEMBER_LICENSE_EXCEEDED
|
||||
description = "Workspace member license capacity reached."
|
||||
hint = "Contact your workspace administrator to expand the license seat count."
|
||||
|
||||
|
||||
class HumanInputFormNotFound(OpenApiError): # noqa: N818
|
||||
code = 404
|
||||
error_code = OpenApiErrorCode.HUMAN_INPUT_FORM_NOT_FOUND
|
||||
description = "No human-input form matches this token. It may be wrong, expired, or already submitted."
|
||||
|
||||
|
||||
class RecipientSurfaceMismatch(OpenApiError): # noqa: N818
|
||||
code = 403
|
||||
error_code = OpenApiErrorCode.RECIPIENT_SURFACE_MISMATCH
|
||||
description = "This form's recipient can't be submitted via the OpenAPI surface."
|
||||
hint = "Action it through its channel (web app or console)."
|
||||
|
||||
@ -2,8 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any, Final, Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
@ -14,30 +13,6 @@ from models.model import AppMode
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class SupportedAppType(StrEnum):
|
||||
"""App types the ``app`` usage face (``get app``) lists and filters.
|
||||
|
||||
A curated subset of :class:`AppMode`: the real, user-facing app categories.
|
||||
Excludes runtime-only mode tags that are not standalone apps
|
||||
(``rag-pipeline`` is a knowledge ``Pipeline``; ``channel`` is unused) and the
|
||||
roster-owned ``agent`` type (surfaced through the roster, not this list).
|
||||
|
||||
Members reference ``AppMode.*.value`` so the subset relationship is
|
||||
type-checked: dropping a member from ``AppMode`` breaks this at import.
|
||||
This is the single source for the listable set — params, filters, and the
|
||||
generated CLI whitelist all derive from it.
|
||||
"""
|
||||
|
||||
COMPLETION = AppMode.COMPLETION.value
|
||||
CHAT = AppMode.CHAT.value
|
||||
ADVANCED_CHAT = AppMode.ADVANCED_CHAT.value
|
||||
WORKFLOW = AppMode.WORKFLOW.value
|
||||
AGENT_CHAT = AppMode.AGENT_CHAT.value
|
||||
|
||||
|
||||
SUPPORTED_APP_TYPES: Final[tuple[AppMode, ...]] = tuple(AppMode(t.value) for t in SupportedAppType)
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
@ -63,12 +38,18 @@ class PaginationEnvelope[T](BaseModel):
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
@ -89,14 +70,16 @@ class PermittedExternalAppsListResponse(BaseModel):
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfo(BaseModel):
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfo):
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
@ -304,13 +287,14 @@ class AppDescribeQuery(BaseModel):
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum of listable app types."""
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: UUIDStr
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: SupportedAppType | None = None
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
@ -360,7 +344,7 @@ class PermittedExternalAppsListQuery(BaseModel):
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: SupportedAppType | None = None
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
|
||||
@ -5,12 +5,11 @@ from typing import cast
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppDslExportQuery, AppDslExportResponse, AppDslImportPayload
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, App
|
||||
@ -38,11 +37,6 @@ class AppDslImportApi(Resource):
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(
|
||||
resource_type=RBACResourceScope.APP,
|
||||
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
|
||||
resource_required=False,
|
||||
),
|
||||
)
|
||||
@returns(200, Import, "Import completed")
|
||||
@returns(202, Import, "Import pending confirmation")
|
||||
@ -95,11 +89,6 @@ class AppDslImportConfirmApi(Resource):
|
||||
scope=Scope.WORKSPACE_WRITE,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(
|
||||
resource_type=RBACResourceScope.APP,
|
||||
scene=RBACPermission.APP_IMPORT_EXPORT_DSL,
|
||||
resource_required=False,
|
||||
),
|
||||
)
|
||||
@returns(200, Import, "Import confirmed")
|
||||
@returns(400, Import, "Import failed")
|
||||
@ -136,7 +125,6 @@ class AppDslExportApi(Resource):
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
|
||||
)
|
||||
@accepts(query=AppDslExportQuery)
|
||||
@returns(200, AppDslExportResponse, "Export successful")
|
||||
@ -167,7 +155,6 @@ class AppDslCheckDependenciesApi(Resource):
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
allowed_roles=frozenset({TenantAccountRole.EDITOR, TenantAccountRole.ADMIN, TenantAccountRole.OWNER}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_IMPORT_EXPORT_DSL),
|
||||
)
|
||||
@returns(200, CheckDependenciesResult, "Dependencies checked")
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
|
||||
@ -19,13 +19,12 @@ from werkzeug.exceptions import (
|
||||
|
||||
import services
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import AppRunRequest, TaskStopResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -137,10 +136,7 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@openapi_ns.response(200, "Run result (SSE stream)", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@accepts(body=AppRunRequest)
|
||||
def post(self, app_id: str, *, auth_data: AuthData, body: AppRunRequest):
|
||||
@ -171,10 +167,7 @@ class AppRunApi(Resource):
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(200, TaskStopResponse, description="Task stopped")
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
|
||||
@ -8,41 +8,33 @@ from typing import Any, cast
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.app_access import AppAccessFilter, resolve_app_access_filter
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
SUPPORTED_APP_TYPES,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
def _is_listable(app: App) -> bool:
|
||||
"""Whether the openapi app face exposes this app (curated, listable types only)."""
|
||||
return app.mode in SUPPORTED_APP_TYPES
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
@ -92,55 +84,54 @@ def parameters_payload(app: App) -> dict:
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
def build_app_describe_response(app: App, fields: set[str] | None) -> AppDescribeResponse:
|
||||
"""Public projection of an app (name / params / input schema) — never internal config."""
|
||||
want_info = fields is None or "info" in fields
|
||||
want_params = fields is None or "parameters" in fields
|
||||
want_schema = fields is None or "input_schema" in fields
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in (AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return AppDescribeResponse(info=info, parameters=parameters, input_schema=input_schema)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}),
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_VIEW_LAYOUT),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@returns(200, AppDescribeResponse, description="App description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# describe is UUID-only (workspace_id query param dropped in #37212).
|
||||
app = self._load(app_id)
|
||||
return build_app_describe_response(app, query.fields)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
@ -161,57 +152,45 @@ class AppListApi(Resource):
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
# Compute RBAC-accessible app IDs when RBAC is enabled and the caller is an account.
|
||||
# ``None`` means unrestricted (caller can see all apps in the workspace);
|
||||
# an empty set or list means the caller has no accessible apps.
|
||||
# End-users bypass RBAC here — their access is controlled by scope upstream.
|
||||
apply_rbac_filter = (
|
||||
dify_config.RBAC_ENABLED and auth_data.caller_kind != "end_user" and auth_data.account_id is not None
|
||||
)
|
||||
access_filter = AppAccessFilter.unrestricted()
|
||||
if apply_rbac_filter:
|
||||
access_filter = resolve_app_access_filter(workspace_id, str(auth_data.account_id))
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
if not _is_listable(app):
|
||||
return empty
|
||||
# Apply RBAC visibility to the UUID fast-path the same way the service
|
||||
# layer does for paginated queries (id in accessible set OR own app).
|
||||
if apply_rbac_filter and not access_filter.is_app_accessible(
|
||||
str(app.id), str(app.maintainer) if app.maintainer else None, str(auth_data.account_id)
|
||||
):
|
||||
return empty
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag, db.session)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
if apply_rbac_filter:
|
||||
access_filter.apply_to_params(params)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params, db.session)
|
||||
if pagination is None:
|
||||
return empty
|
||||
@ -226,12 +205,13 @@ class AppListApi(Resource):
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
if _is_listable(r)
|
||||
]
|
||||
|
||||
env = AppListResponse(
|
||||
|
||||
@ -8,18 +8,14 @@ EE blueprint chain so this module is unreachable there.
|
||||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.apps import build_app_describe_response
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
@ -71,7 +67,9 @@ class PermittedExternalAppsListApi(Resource):
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
@ -84,20 +82,3 @@ class PermittedExternalAppsListApi(Resource):
|
||||
data=items,
|
||||
)
|
||||
return env
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps/<string:app_id>/describe")
|
||||
class PermittedExternalAppDescribeApi(Resource):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
@returns(200, AppDescribeResponse, description="Permitted external app description")
|
||||
@accepts(query=AppDescribeQuery)
|
||||
def get(self, app_id: str, *, auth_data: AuthData, query: AppDescribeQuery):
|
||||
# App already loaded and ACL-checked by the external_sso pipeline; project it.
|
||||
app = auth_data.app
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
return build_app_describe_response(app, query.fields)
|
||||
|
||||
@ -3,11 +3,9 @@ from __future__ import annotations
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
HAS_RBAC,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WEBAPP_RUN_SCOPED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
WORKSPACE_SCOPED,
|
||||
)
|
||||
@ -27,7 +25,6 @@ from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_rbac_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
@ -50,9 +47,8 @@ account_pipeline = AuthPipeline(
|
||||
When(WORKSPACE_SCOPED, then=check_workspace_member),
|
||||
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(HAS_RBAC, then=check_rbac_permission),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED & WEBAPP_RUN_SCOPED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE & WEBAPP_RUN_SCOPED, then=check_private_app_permission),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -50,11 +50,8 @@ EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WEBAPP_RUN_SCOPED = request_cond(lambda ctx: ctx.scope == Scope.APPS_RUN)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
HAS_RBAC = request_cond(lambda ctx: ctx.rbac is not None)
|
||||
|
||||
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
|
||||
# from the app) or an explicit workspace-membership path (tenant from request).
|
||||
|
||||
@ -8,7 +8,6 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from core.rbac import RBACPermission, RBACResourceScope
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.model import App, EndUser
|
||||
@ -36,14 +35,6 @@ class ExternalIdentity(BaseModel):
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RBACRequirement(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
resource_type: RBACResourceScope
|
||||
scene: RBACPermission
|
||||
resource_required: bool = True
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
@ -52,7 +43,6 @@ class RequestContext(BaseModel):
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
rbac: RBACRequirement | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
@ -69,7 +59,6 @@ class AuthData(BaseModel):
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
rbac: RBACRequirement | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
|
||||
@ -21,7 +21,6 @@ from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RBACRequirement,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
@ -60,7 +59,6 @@ class AuthPipeline:
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
@ -68,7 +66,6 @@ class AuthPipeline:
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
@ -80,7 +77,6 @@ class AuthPipeline:
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
@ -133,7 +129,6 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
@ -141,7 +136,6 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
@ -151,7 +145,6 @@ class PipelineRouter:
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
@ -159,7 +152,6 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
@ -170,7 +162,6 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
rbac: RBACRequirement | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
@ -184,7 +175,6 @@ class PipelineRouter:
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
return decorated
|
||||
@ -202,7 +192,6 @@ class PipelineRouter:
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
rbac: RBACRequirement | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
@ -246,7 +235,6 @@ class PipelineRouter:
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
rbac=rbac,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -74,13 +74,12 @@ def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
match raw:
|
||||
case None:
|
||||
return None
|
||||
case SubjectType():
|
||||
return raw
|
||||
case str():
|
||||
return SubjectType(raw)
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -3,8 +3,6 @@ from __future__ import annotations
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.wraps import enforce_rbac_access
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
@ -40,9 +38,6 @@ def check_workspace_mismatch(data: AuthData) -> None:
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if dify_config.RBAC_ENABLED and data.rbac is not None:
|
||||
# fine-grained permission check is performed by RBAC
|
||||
return
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
@ -51,27 +46,6 @@ def check_workspace_role(data: AuthData) -> None:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_rbac_permission(data: AuthData) -> None:
|
||||
req = data.rbac
|
||||
if req is None:
|
||||
return
|
||||
if not dify_config.RBAC_ENABLED:
|
||||
return
|
||||
# Only account callers are subject to RBAC; end_user access is scope-controlled.
|
||||
if data.caller_kind != "account":
|
||||
return
|
||||
if data.account_id is None or data.tenant is None:
|
||||
raise Forbidden("rbac context missing")
|
||||
enforce_rbac_access(
|
||||
tenant_id=str(data.tenant.id),
|
||||
account_id=str(data.account_id),
|
||||
resource_type=req.resource_type,
|
||||
scene=req.scene,
|
||||
resource_required=req.resource_required,
|
||||
path_args=dict(data.path_params),
|
||||
)
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
|
||||
@ -12,21 +12,16 @@ import logging
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._contract import accepts, returns
|
||||
from controllers.openapi._errors import HumanInputFormNotFound, RecipientSurfaceMismatch
|
||||
from controllers.openapi._models import FormSubmitResponse, HumanInputFormDefinitionResponse
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from core.workflow.human_input_policy import (
|
||||
HumanInputSurface,
|
||||
is_recipient_type_allowed_for_surface,
|
||||
)
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
@ -52,37 +47,31 @@ def _jsonify_form_definition(form) -> Response:
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise HumanInputFormNotFound()
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise RecipientSurfaceMismatch()
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition", openapi_ns.models[HumanInputFormDefinitionResponse.__name__])
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, _caller, _caller_kind = auth_data.require_app_context()
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise HumanInputFormNotFound()
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
@returns(200, FormSubmitResponse, description="Form submitted")
|
||||
@accepts(body=HumanInputFormSubmitPayload)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData, body: HumanInputFormSubmitPayload):
|
||||
@ -91,7 +80,7 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise HumanInputFormNotFound()
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
@ -117,6 +106,6 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise HumanInputFormNotFound()
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return FormSubmitResponse()
|
||||
|
||||
@ -19,10 +19,9 @@ from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import EventStreamResponse
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.common.wraps import RBACPermission, RBACResourceScope
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, RBACRequirement
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
@ -47,10 +46,7 @@ class WorkflowEventsQuery(BaseModel):
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(WorkflowEventsQuery))
|
||||
@openapi_ns.response(200, "SSE event stream", openapi_ns.models[EventStreamResponse.__name__])
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_RUN,
|
||||
rbac=RBACRequirement(resource_type=RBACResourceScope.APP, scene=RBACPermission.APP_TEST_AND_RUN),
|
||||
)
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
@ -10,11 +9,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentScope, AgentSource, AgentStatus
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
@ -34,40 +29,6 @@ class AppMetaResponse(ResponseModel):
|
||||
register_response_schema_models(service_api_ns, Parameters, AppMetaResponse, AppInfoResponse)
|
||||
|
||||
|
||||
def _get_agent_app_feature_dict_and_user_input_form(app_model: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
app_model_config = app_model.app_model_config
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict()) if app_model_config is not None else {}
|
||||
|
||||
agent = db.session.scalar(
|
||||
select(Agent)
|
||||
.where(
|
||||
Agent.tenant_id == app_model.tenant_id,
|
||||
Agent.app_id == app_model.id,
|
||||
Agent.scope == AgentScope.ROSTER,
|
||||
Agent.source == AgentSource.AGENT_APP,
|
||||
Agent.status == AgentStatus.ACTIVE,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if agent is None or not agent.active_config_snapshot_id:
|
||||
raise AppUnavailableError()
|
||||
|
||||
snapshot = db.session.scalar(
|
||||
select(AgentConfigSnapshot)
|
||||
.where(
|
||||
AgentConfigSnapshot.tenant_id == app_model.tenant_id,
|
||||
AgentConfigSnapshot.agent_id == agent.id,
|
||||
AgentConfigSnapshot.id == agent.active_config_snapshot_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if snapshot is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
|
||||
return features_dict, agent_app_variables_to_user_input_form(agent_soul.app_variables)
|
||||
|
||||
|
||||
@service_api_ns.route("/parameters")
|
||||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
@ -100,16 +61,12 @@ class AppParameterApi(Resource):
|
||||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
"""
|
||||
features_dict: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
if app_model.mode == AppMode.AGENT:
|
||||
features_dict, user_input_form = _get_agent_app_feature_dict_and_user_input_form(app_model)
|
||||
elif app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
|
||||
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -22,7 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import create_session, session_factory
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
with create_session() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
if not app_record:
|
||||
@ -204,8 +204,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before workflow execution so a checked-out DB connection
|
||||
# is not held for the lifetime of the graph run.
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
@ -370,7 +368,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with create_session() as session, session.begin():
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
|
||||
@ -21,7 +21,6 @@ from core.app.app_config.entities import (
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps.agent_app.app_variable_projection import agent_app_variables_to_user_input_form
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
@ -99,7 +98,8 @@ class AgentAppConfigManager(BaseAppConfigManager):
|
||||
# pipeline's bookkeeping (token counting, persistence).
|
||||
base["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
|
||||
base["pre_prompt"] = agent_soul.prompt.system_prompt or ""
|
||||
base["user_input_form"] = agent_app_variables_to_user_input_form(agent_soul.app_variables)
|
||||
# Agent App takes the user message directly; no completion-style inputs form.
|
||||
base.setdefault("user_input_form", [])
|
||||
return base
|
||||
|
||||
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from models.agent_config_entities import AppVariableConfig
|
||||
|
||||
|
||||
def agent_app_variables_to_user_input_form(app_variables: Sequence[AppVariableConfig]) -> list[dict[str, Any]]:
|
||||
"""Project Agent Soul app variables into the legacy service-API parameter form."""
|
||||
|
||||
user_input_form: list[dict[str, Any]] = []
|
||||
for variable in app_variables:
|
||||
form_type = _form_type_for_agent_variable(variable.type)
|
||||
form_item: dict[str, Any] = {
|
||||
"label": variable.name,
|
||||
"variable": variable.name,
|
||||
"required": variable.required,
|
||||
}
|
||||
if variable.default is not None:
|
||||
form_item["default"] = variable.default
|
||||
user_input_form.append({form_type: form_item})
|
||||
return user_input_form
|
||||
|
||||
|
||||
def _form_type_for_agent_variable(variable_type: str) -> str:
|
||||
normalized = variable_type.strip().lower()
|
||||
if normalized in {"number", "integer", "float"}:
|
||||
return "number"
|
||||
if normalized in {"boolean", "bool"}:
|
||||
return "checkbox"
|
||||
if normalized in {"paragraph", "long_text", "multiline"}:
|
||||
return "paragraph"
|
||||
return "text-input"
|
||||
|
||||
|
||||
__all__ = ["agent_app_variables_to_user_input_form"]
|
||||
@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
@ -47,10 +47,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
app_stmt = select(App).where(App.id == app_config.app_id)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(app_stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
app_record = db.session.scalar(app_stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -188,18 +185,14 @@ class AgentChatAppRunner(AppRunner):
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
|
||||
conversation_result = db.session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
msg_stmt = select(Message).where(Message.id == message.id)
|
||||
with create_session() as session:
|
||||
conversation_result = session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
message_result = session.scalar(msg_stmt)
|
||||
if message_result is not None:
|
||||
session.expunge(message_result)
|
||||
session.expunge(conversation_result)
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
|
||||
@ -11,7 +11,6 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
@ -47,10 +46,7 @@ class ChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
app_record = db.session.scalar(stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -220,8 +216,6 @@ class ChatAppRunner(AppRunner):
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
|
||||
# is not held for the lifetime of the provider response.
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
|
||||
@ -51,11 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import (
|
||||
load_form_dispositions_by_form_id,
|
||||
)
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import (
|
||||
FormDisposition,
|
||||
HumanInputSurface,
|
||||
enrich_human_input_pause_reasons,
|
||||
resolve_human_input_pause_reason_inputs,
|
||||
@ -343,14 +340,13 @@ class WorkflowResponseConverter:
|
||||
human_input_form_ids = [reason.form_id for reason in resolved_reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
dispositions_by_form_id: dict[str, FormDisposition] = {}
|
||||
form_token_by_form_id: dict[str, str] = {}
|
||||
if human_input_form_ids:
|
||||
stmt = select(
|
||||
HumanInputForm.id,
|
||||
HumanInputForm.expiration_time,
|
||||
HumanInputForm.form_definition,
|
||||
).where(HumanInputForm.id.in_(human_input_form_ids))
|
||||
hitl_surface = _INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from)
|
||||
with Session(bind=db.engine) as session:
|
||||
for form_id, expiration_time, form_definition in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
||||
@ -359,17 +355,17 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
dispositions_by_form_id = load_form_dispositions_by_form_id(
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=hitl_surface,
|
||||
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
pause_reasons = enrich_human_input_pause_reasons(
|
||||
pause_reasons,
|
||||
dispositions_by_form_id=dispositions_by_form_id,
|
||||
form_tokens_by_form_id=form_token_by_form_id,
|
||||
expiration_times_by_form_id={
|
||||
form_id: int(expiration_time.timestamp())
|
||||
for form_id, expiration_time in expiration_times_by_form_id.items()
|
||||
@ -383,7 +379,6 @@ class WorkflowResponseConverter:
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
||||
disposition = dispositions_by_form_id.get(reason.form_id)
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
@ -396,8 +391,7 @@ class WorkflowResponseConverter:
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False),
|
||||
form_token=disposition.form_token if disposition else None,
|
||||
approval_channels=list(disposition.approval_channels) if disposition else [],
|
||||
form_token=form_token_by_form_id.get(reason.form_id),
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=int(expiration_time.timestamp()),
|
||||
),
|
||||
|
||||
@ -14,8 +14,8 @@ from core.app.app_config.easy_ui_based_app.model_config.converter import ModelCo
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
from core.app.apps.completion.workflow_runner import CompletionWorkflowRunner
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
@ -218,7 +218,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = CompletionWorkflowRunner()
|
||||
runner = CompletionAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
|
||||
194
api/core/app/apps/completion/app_runner.py
Normal file
194
api/core/app/apps/completion/app_runner.py
Normal file
@ -0,0 +1,194 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from models.model import App, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompletionAppRunner(AppRunner):
|
||||
"""
|
||||
Completion Application Runner
|
||||
"""
|
||||
|
||||
def run(
|
||||
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
|
||||
):
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
image_detail_config = (
|
||||
application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
application_generate_entity.file_upload_config
|
||||
and application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
# organize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
|
||||
# moderation
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query or "",
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(e),
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return
|
||||
|
||||
# fill in variable inputs from external data tools if exists
|
||||
external_data_tools = app_config.external_data_variables
|
||||
if external_data_tools:
|
||||
inputs = self.fill_in_inputs_from_external_data_tools(
|
||||
tenant_id=app_record.tenant_id,
|
||||
app_id=app_record.id,
|
||||
external_data_tools=external_data_tools,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
)
|
||||
|
||||
# get context from datasets
|
||||
context = None
|
||||
context_files: list[File] = []
|
||||
if app_config.dataset and app_config.dataset.dataset_ids:
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager,
|
||||
app_record.id,
|
||||
message.id,
|
||||
application_generate_entity.user_id,
|
||||
application_generate_entity.invoke_from,
|
||||
)
|
||||
|
||||
dataset_config = app_config.dataset
|
||||
if dataset_config and dataset_config.retrieve_config.query_variable:
|
||||
query = inputs.get(dataset_config.retrieve_config.query_variable, "")
|
||||
|
||||
dataset_retrieval = DatasetRetrieval(application_generate_entity)
|
||||
context, retrieved_files = dataset_retrieval.retrieve(
|
||||
app_id=app_record.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_record.tenant_id,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
config=dataset_config,
|
||||
query=query or "",
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
show_retrieve_source=app_config.additional_features.show_retrieve_source
|
||||
if app_config.additional_features
|
||||
else False,
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
# reorganize all inputs and template to prompt messages
|
||||
# Include: prompt template, inputs, query(optional), files(optional)
|
||||
# memory(optional), external data, dataset context(optional)
|
||||
prompt_messages, stop = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# check hosting moderation
|
||||
hosting_moderation_result = self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
if hosting_moderation_result:
|
||||
return
|
||||
|
||||
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
|
||||
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
|
||||
|
||||
# Invoke model
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=application_generate_entity.model_conf.parameters,
|
||||
stop=stop,
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result,
|
||||
queue_manager=queue_manager,
|
||||
stream=application_generate_entity.stream,
|
||||
message_id=message.id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
)
|
||||
@ -1,148 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from core.prompt.utils.prompt_message_util import SavedPrompt
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
_LLM_TEXT_SELECTOR_PREFIX = ("llm", "text")
|
||||
|
||||
|
||||
class CompletionGraphEventAdapter:
|
||||
"""Translate runtime workflow events into legacy Completion queue events."""
|
||||
|
||||
_application_generate_entity: CompletionAppGenerateEntity
|
||||
_queue_manager: AppQueueManager
|
||||
_answer: str
|
||||
_usage: LLMUsage
|
||||
_saved_prompt: list[SavedPrompt]
|
||||
_chunk_index: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._answer = ""
|
||||
self._usage = LLMUsage.empty_usage()
|
||||
self._saved_prompt = []
|
||||
self._chunk_index = 0
|
||||
|
||||
def handle_event(self, event: GraphEngineEvent) -> None:
|
||||
match event:
|
||||
case NodeRunStreamChunkEvent():
|
||||
self._handle_stream_chunk(event)
|
||||
case NodeRunRetrieverResourceEvent():
|
||||
self._handle_retriever_resource(event)
|
||||
case NodeRunSucceededEvent():
|
||||
self._handle_node_succeeded(event)
|
||||
case NodeRunFailedEvent() | NodeRunExceptionEvent():
|
||||
self._publish_error(event.error or event.node_run_result.error or "Node failed")
|
||||
case GraphRunSucceededEvent():
|
||||
self._publish_message_end(event.outputs)
|
||||
case GraphRunFailedEvent():
|
||||
self._publish_error(event.error)
|
||||
case GraphRunAbortedEvent():
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
case _:
|
||||
return
|
||||
|
||||
def _handle_stream_chunk(self, event: NodeRunStreamChunkEvent) -> None:
|
||||
if tuple(event.selector)[:2] != _LLM_TEXT_SELECTOR_PREFIX:
|
||||
return
|
||||
if event.is_final and not event.chunk:
|
||||
return
|
||||
|
||||
self._answer += event.chunk
|
||||
self._queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=LLMResultChunk(
|
||||
model=self._application_generate_entity.model_conf.model,
|
||||
prompt_messages=[],
|
||||
delta=LLMResultChunkDelta(
|
||||
index=self._chunk_index,
|
||||
message=AssistantPromptMessage(content=event.chunk),
|
||||
),
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._chunk_index += 1
|
||||
|
||||
def _handle_retriever_resource(self, event: NodeRunRetrieverResourceEvent) -> None:
|
||||
self._queue_manager.publish(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=[
|
||||
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
|
||||
],
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
if event.node_type != BuiltinNodeTypes.LLM and event.node_id != "llm":
|
||||
return
|
||||
|
||||
result = event.node_run_result
|
||||
text = result.outputs.get("text")
|
||||
if isinstance(text, str):
|
||||
self._answer = text
|
||||
self._usage = result.llm_usage
|
||||
|
||||
prompts = result.process_data.get("prompts")
|
||||
if isinstance(prompts, list):
|
||||
self._saved_prompt = cast(list[SavedPrompt], prompts)
|
||||
|
||||
def _publish_message_end(self, outputs: Mapping[str, object]) -> None:
|
||||
result = outputs.get("result")
|
||||
if isinstance(result, str) and not self._answer:
|
||||
self._answer = result
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self._application_generate_entity.model_conf.model,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=self._answer),
|
||||
usage=self._usage,
|
||||
),
|
||||
saved_prompt=self._saved_prompt,
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _publish_error(self, error: Any) -> None:
|
||||
self._queue_manager.publish(
|
||||
QueueErrorEvent(error=ValueError(str(error))),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
@ -1,64 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from models.model import App, AppMode
|
||||
from services.workflow.workflow_converter import WorkflowConverter, WorkflowGraph
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RuntimeCompletionWorkflow:
|
||||
workflow_id: str
|
||||
root_node_id: str
|
||||
graph_dict: WorkflowGraph
|
||||
|
||||
|
||||
class RuntimeCompletionWorkflowBuilder:
|
||||
"""Build the transient WorkflowEntry graph used by Completion execution."""
|
||||
|
||||
def __init__(self, workflow_converter: WorkflowConverter | None = None) -> None:
|
||||
self._workflow_converter = workflow_converter or WorkflowConverter()
|
||||
|
||||
def build(self, *, app_model: App, app_config: CompletionAppConfig) -> RuntimeCompletionWorkflow:
|
||||
graph, _ = self._workflow_converter.build_graph_from_app_config(
|
||||
app_model=app_model,
|
||||
app_config=app_config,
|
||||
target_app_mode=AppMode.WORKFLOW,
|
||||
)
|
||||
self._route_external_data_query_to_sys_query(graph)
|
||||
return RuntimeCompletionWorkflow(
|
||||
workflow_id=f"completion-runtime-{uuid4()}",
|
||||
root_node_id="start",
|
||||
graph_dict=graph,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _route_external_data_query_to_sys_query(graph: WorkflowGraph) -> None:
|
||||
"""Preserve Completion API-based variable behavior in the runtime graph."""
|
||||
for node in graph["nodes"]:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") != BuiltinNodeTypes.HTTP_REQUEST:
|
||||
continue
|
||||
|
||||
body = data.get("body")
|
||||
if not isinstance(body, dict) or body.get("type") != "json":
|
||||
continue
|
||||
|
||||
raw_body_data = body.get("data")
|
||||
if not isinstance(raw_body_data, str):
|
||||
continue
|
||||
|
||||
try:
|
||||
body_data: dict[str, Any] = json.loads(raw_body_data)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
params = body_data.get("params")
|
||||
if not isinstance(params, dict) or params.get("query") != "":
|
||||
continue
|
||||
|
||||
params["query"] = "{{#sys.query#}}"
|
||||
body["data"] = json.dumps(body_data)
|
||||
@ -1,250 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfig
|
||||
from core.app.apps.completion.graph_event_adapter import CompletionGraphEventAdapter
|
||||
from core.app.apps.completion.runtime_workflow_builder import RuntimeCompletionWorkflowBuilder
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.workflow_app_runner import init_graph
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, UserFrom
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.node_runtime import DIFY_BEFORE_LLM_INVOKE_KEY
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_hosting_provider import hosting_configuration
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.command_channels import RedisChannel
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from models.model import App, Message
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ModeratedCompletionInputs:
|
||||
stopped: bool
|
||||
inputs: Mapping[str, Any]
|
||||
query: str
|
||||
|
||||
|
||||
class CompletionWorkflowRunner(AppRunner):
|
||||
"""Run Completion through a transient WorkflowEntry graph."""
|
||||
|
||||
_runtime_workflow_builder: RuntimeCompletionWorkflowBuilder
|
||||
|
||||
def __init__(self, runtime_workflow_builder: RuntimeCompletionWorkflowBuilder | None = None) -> None:
|
||||
self._runtime_workflow_builder = runtime_workflow_builder or RuntimeCompletionWorkflowBuilder()
|
||||
|
||||
def run(
|
||||
self,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
) -> None:
|
||||
app_config = cast(CompletionAppConfig, application_generate_entity.app_config)
|
||||
app_record = self._get_app(app_config.app_id)
|
||||
|
||||
moderation_result = self._run_input_moderation(
|
||||
app_record=app_record,
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
)
|
||||
if moderation_result.stopped:
|
||||
return
|
||||
|
||||
runtime_workflow = self._runtime_workflow_builder.build(app_model=app_record, app_config=app_config)
|
||||
variable_pool = self._build_variable_pool(
|
||||
application_generate_entity=application_generate_entity,
|
||||
message=message,
|
||||
workflow_id=runtime_workflow.workflow_id,
|
||||
root_node_id=runtime_workflow.root_node_id,
|
||||
inputs=moderation_result.inputs,
|
||||
query=moderation_result.query,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
user_from = self._resolve_user_from(application_generate_entity)
|
||||
extra_context: dict[str, Any] = {}
|
||||
if self._should_check_hosting_moderation(application_generate_entity):
|
||||
extra_context[DIFY_BEFORE_LLM_INVOKE_KEY] = self._build_hosting_moderation_hook(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
|
||||
graph = init_graph(
|
||||
app_id=app_config.app_id,
|
||||
graph_config=runtime_workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
user_from=user_from,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_id=runtime_workflow.workflow_id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
root_node_id=runtime_workflow.root_node_id,
|
||||
trace_session_id=application_generate_entity.extras.get("trace_session_id"),
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
|
||||
queue_manager.graph_runtime_state = graph_runtime_state
|
||||
command_channel = RedisChannel(redis_client, f"workflow:{application_generate_entity.task_id}:commands")
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
workflow_id=runtime_workflow.workflow_id,
|
||||
graph_config=runtime_workflow.graph_dict,
|
||||
graph=graph,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
adapter = CompletionGraphEventAdapter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
for event in workflow_entry.run():
|
||||
adapter.handle_event(event)
|
||||
|
||||
def _get_app(self, app_id: str) -> App:
|
||||
app_record = db.session.scalar(select(App).where(App.id == app_id))
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
return app_record
|
||||
|
||||
def _run_input_moderation(
|
||||
self,
|
||||
*,
|
||||
app_record: App,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
) -> ModeratedCompletionInputs:
|
||||
app_config = cast(CompletionAppConfig, application_generate_entity.app_config)
|
||||
prompt_messages, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=application_generate_entity.inputs,
|
||||
files=application_generate_entity.files,
|
||||
query=application_generate_entity.query,
|
||||
image_detail_config=self._resolve_image_detail_config(application_generate_entity),
|
||||
)
|
||||
|
||||
try:
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationError as exc:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
prompt_messages=prompt_messages,
|
||||
text=str(exc),
|
||||
stream=application_generate_entity.stream,
|
||||
)
|
||||
return ModeratedCompletionInputs(
|
||||
stopped=True,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
)
|
||||
|
||||
return ModeratedCompletionInputs(stopped=False, inputs=inputs, query=query)
|
||||
|
||||
def _build_hosting_moderation_hook(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
) -> Callable[[Sequence[PromptMessage]], None]:
|
||||
def check(prompt_messages: Sequence[PromptMessage]) -> None:
|
||||
if self.check_hosting_moderation(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
prompt_messages=list(prompt_messages),
|
||||
):
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
return check
|
||||
|
||||
def _should_check_hosting_moderation(self, application_generate_entity: CompletionAppGenerateEntity) -> bool:
|
||||
moderation_config = hosting_configuration.moderation_config
|
||||
openai_provider_name = f"{DEFAULT_PLUGIN_ID}/openai/openai"
|
||||
hosting_provider = hosting_configuration.provider_map.get(openai_provider_name)
|
||||
if not (
|
||||
moderation_config
|
||||
and moderation_config.enabled is True
|
||||
and hosting_provider
|
||||
and hosting_provider.enabled is True
|
||||
and hosting_provider.credentials is not None
|
||||
):
|
||||
return False
|
||||
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = getattr(model_config, "provider_model_bundle", None)
|
||||
configuration = getattr(provider_model_bundle, "configuration", None)
|
||||
using_provider_type = getattr(configuration, "using_provider_type", None)
|
||||
return (
|
||||
using_provider_type == ProviderType.SYSTEM
|
||||
and getattr(model_config, "provider", None) in moderation_config.providers
|
||||
)
|
||||
|
||||
def _build_variable_pool(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
message: Message,
|
||||
workflow_id: str,
|
||||
root_node_id: str,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
) -> VariablePool:
|
||||
variable_pool = VariablePool()
|
||||
system_inputs = build_system_variables(
|
||||
files=application_generate_entity.files,
|
||||
user_id=application_generate_entity.user_id,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_execution_id=application_generate_entity.task_id,
|
||||
timestamp=int(time.time()),
|
||||
query=query,
|
||||
conversation_id=getattr(message, "conversation_id", None),
|
||||
)
|
||||
add_variables_to_pool(
|
||||
variable_pool,
|
||||
build_bootstrap_variables(system_variables=system_inputs, environment_variables=[]),
|
||||
)
|
||||
add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs)
|
||||
return variable_pool
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_from(application_generate_entity: CompletionAppGenerateEntity) -> UserFrom:
|
||||
if application_generate_entity.invoke_from.runs_as_account():
|
||||
return UserFrom.ACCOUNT
|
||||
return UserFrom.END_USER
|
||||
|
||||
@staticmethod
|
||||
def _resolve_image_detail_config(
|
||||
application_generate_entity: CompletionAppGenerateEntity,
|
||||
) -> ImagePromptMessageContent.DETAIL:
|
||||
file_upload_config = application_generate_entity.file_upload_config
|
||||
if file_upload_config and file_upload_config.image_config:
|
||||
return file_upload_config.image_config.detail or ImagePromptMessageContent.DETAIL.LOW
|
||||
return ImagePromptMessageContent.DETAIL.LOW
|
||||
@ -11,7 +11,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class MessageBasedAppQueueManager(AppQueueManager):
|
||||
@ -48,6 +47,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
if self._app_mode == AppMode.ADVANCED_CHAT.value:
|
||||
return
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
@ -3,7 +3,6 @@ import time
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
@ -15,12 +14,12 @@ from core.app.entities.app_invoke_entities import (
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import create_session
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
@ -84,24 +83,22 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
user_id = None
|
||||
with create_session() as session:
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
pipeline = db.session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
session.expunge(pipeline)
|
||||
session.expunge(workflow)
|
||||
db.session.close()
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
@ -211,12 +208,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = session.scalar(
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.limit(1)
|
||||
@ -301,11 +298,11 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
with create_session() as session, session.begin():
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
session.add(document)
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -42,3 +43,6 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedError()
|
||||
|
||||
@ -88,60 +88,6 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_graph(
|
||||
*,
|
||||
app_id: str,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
root_node_id: str | None = None,
|
||||
trace_session_id: str | None = None,
|
||||
call_depth: int = 0,
|
||||
extra_context: Mapping[str, Any] | None = None,
|
||||
) -> Graph:
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
trace_session_id=trace_session_id,
|
||||
extra_context=extra_context,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
run_context=run_context,
|
||||
call_depth=call_depth,
|
||||
)
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
if root_node_id is None:
|
||||
root_node_id = get_default_root_node_id(graph_config)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
def __init__(
|
||||
self,
|
||||
@ -177,18 +123,48 @@ class WorkflowBasedAppRunner:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
return init_graph(
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# Create explicit graph init context for Graph.init.
|
||||
run_context = build_dify_run_context(
|
||||
tenant_id=tenant_id or "",
|
||||
app_id=self._app_id,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
root_node_id=root_node_id,
|
||||
trace_session_id=trace_session_id,
|
||||
)
|
||||
graph_init_context = DifyGraphInitContext(
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
run_context=run_context,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
if root_node_id is None:
|
||||
root_node_id = get_default_root_node_id(graph_config)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
|
||||
@ -6,7 +6,6 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.prompt.utils.prompt_message_util import SavedPrompt
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
@ -253,7 +252,6 @@ class QueueMessageEndEvent(AppQueueEvent):
|
||||
|
||||
event: QueueEvent = QueueEvent.MESSAGE_END
|
||||
llm_result: LLMResult | None = None
|
||||
saved_prompt: list[SavedPrompt] | None = None
|
||||
|
||||
|
||||
class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
|
||||
|
||||
@ -5,7 +5,6 @@ from typing import Any, Literal
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.prompt.utils.prompt_message_util import SavedPrompt
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
@ -44,7 +43,6 @@ class EasyUITaskState(TaskState):
|
||||
"""
|
||||
|
||||
llm_result: LLMResult
|
||||
saved_prompt: list[SavedPrompt] | None = None
|
||||
|
||||
|
||||
class WorkflowTaskState(TaskState):
|
||||
@ -290,7 +288,6 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
actions: Sequence[UserActionConfig] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
approval_channels: list[str] = Field(default_factory=list)
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
||||
|
||||
@ -314,7 +311,6 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
actions: Sequence[UserActionConfig] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
approval_channels: list[str] = Field(default_factory=list)
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@ -329,7 +325,6 @@ class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
approval_channels=data.approval_channels,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
@ -277,9 +277,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline[EasyUIAppGenerat
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
if event.llm_result:
|
||||
self._task_state.llm_result = event.llm_result
|
||||
saved_prompt = getattr(event, "saved_prompt", None)
|
||||
if saved_prompt is not None:
|
||||
self._task_state.saved_prompt = saved_prompt
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
|
||||
@ -396,11 +393,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline[EasyUIAppGenerat
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation {self._conversation_id} not found")
|
||||
|
||||
saved_prompt = self._task_state.saved_prompt
|
||||
if saved_prompt is None:
|
||||
saved_prompt = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
saved_prompt = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode, self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
object.__setattr__(message, "message", saved_prompt)
|
||||
message.message_tokens = usage.prompt_tokens
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
|
||||
@ -14,7 +14,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from http import HTTPStatus
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
@ -294,27 +293,28 @@ class StreamableHTTPTransport:
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
match response.status_code:
|
||||
case HTTPStatus.ACCEPTED:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
case HTTPStatus.NO_CONTENT:
|
||||
logger.debug("Received 204 No Content")
|
||||
return
|
||||
case HTTPStatus.NOT_FOUND:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
error_msg = (
|
||||
f"MCP server URL returned 404 Not Found: {self.url} "
|
||||
"— verify the server URL is correct and the server is running"
|
||||
if is_initialization
|
||||
else "Session terminated by server"
|
||||
)
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
message=error_msg,
|
||||
)
|
||||
return
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 204:
|
||||
logger.debug("Received 204 No Content")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
error_msg = (
|
||||
f"MCP server URL returned 404 Not Found: {self.url} "
|
||||
"— verify the server URL is correct and the server is running"
|
||||
if is_initialization
|
||||
else "Session terminated by server"
|
||||
)
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
message=error_msg,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -12,19 +13,10 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db.session_factory import create_session
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account, TenantAccountJoin
|
||||
from models.model import (
|
||||
App,
|
||||
AppMode,
|
||||
AppModelConfig,
|
||||
AppModelConfigDict,
|
||||
EndUser,
|
||||
load_annotation_reply_config,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@ -38,18 +30,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = cls._get_workflow(app)
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config_dict = cls._get_app_model_config_dict(app)
|
||||
if app_model_config_dict is None:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config_dict)
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
@ -76,7 +68,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not user_id:
|
||||
user = EndUserService.get_or_create_end_user(app)
|
||||
else:
|
||||
user = cls._get_user(user_id, app)
|
||||
user = cls._get_user(user_id)
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
@ -87,10 +79,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
case AppMode.WORKFLOW:
|
||||
workflow = cls._get_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files)
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
case AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
case _:
|
||||
@ -112,7 +101,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = cls._get_workflow(app)
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@ -169,7 +158,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
app: App,
|
||||
workflow: Workflow,
|
||||
user: EndUser | Account,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
@ -178,6 +166,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
@ -215,26 +207,16 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str, app: App) -> EndUser | Account:
|
||||
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
with create_session() as session:
|
||||
stmt = select(EndUser).where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == app.tenant_id,
|
||||
EndUser.app_id == app.id,
|
||||
)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(
|
||||
Account.id == user_id,
|
||||
Account.id == TenantAccountJoin.account_id,
|
||||
TenantAccountJoin.tenant_id == app.tenant_id,
|
||||
)
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
user = session.scalar(stmt)
|
||||
if user:
|
||||
session.expunge(user)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
@ -247,10 +229,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
with create_session() as session:
|
||||
app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
if app:
|
||||
session.expunge(app)
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
@ -258,41 +237,3 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app: App) -> Workflow | None:
|
||||
"""
|
||||
get workflow without relying on App.workflow's request-scoped session property
|
||||
"""
|
||||
if not app.workflow_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if workflow:
|
||||
session.expunge(workflow)
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None:
|
||||
"""
|
||||
get app model config features without relying on request-scoped session-backed model properties
|
||||
"""
|
||||
if not app.app_model_config_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
app_model_config = session.scalar(
|
||||
select(AppModelConfig)
|
||||
.where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if app_model_config is None:
|
||||
return None
|
||||
|
||||
annotation_reply = load_annotation_reply_config(session, app_model_config.app_id)
|
||||
return app_model_config.to_dict(annotation_reply=annotation_reply)
|
||||
|
||||
@ -14,12 +14,6 @@ from core.rag.extractor.watercrawl.exceptions import (
|
||||
|
||||
WATERCRAWL_REQUEST_TIMEOUT: httpx.Timeout = httpx.Timeout(30.0, connect=5.0)
|
||||
|
||||
# The crawl-status stream is a long-lived SSE connection that can stay open for
|
||||
# the whole duration of a crawl, so it keeps an unbounded read while still
|
||||
# capping the initial connection. Regular requests use WATERCRAWL_REQUEST_TIMEOUT
|
||||
# so a stalled endpoint can't hang a worker forever.
|
||||
_STREAM_TIMEOUT = httpx.Timeout(None, connect=10.0)
|
||||
|
||||
|
||||
class SpiderOptions(TypedDict):
|
||||
max_depth: int
|
||||
@ -56,8 +50,6 @@ class BaseAPIClient:
|
||||
"User-Agent": "WaterCrawl-Plugin",
|
||||
"Accept-Language": "en-US",
|
||||
}
|
||||
# Regular requests use WATERCRAWL_REQUEST_TIMEOUT; the long-lived
|
||||
# crawl-status stream overrides it with _STREAM_TIMEOUT in _request.
|
||||
return httpx.Client(headers=headers, timeout=WATERCRAWL_REQUEST_TIMEOUT)
|
||||
|
||||
def _request(
|
||||
@ -71,7 +63,7 @@ class BaseAPIClient:
|
||||
stream = kwargs.pop("stream", False)
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
if stream:
|
||||
request = self.session.build_request(method, url, params=query_params, json=data, timeout=_STREAM_TIMEOUT)
|
||||
request = self.session.build_request(method, url, params=query_params, json=data)
|
||||
return self.session.send(request, stream=True, **kwargs)
|
||||
|
||||
return self.session.request(method, url, params=query_params, json=data, **kwargs)
|
||||
|
||||
@ -22,35 +22,23 @@ class RBACPermission(StrEnum):
|
||||
|
||||
APP_VIEW_LAYOUT = "app_view_layout"
|
||||
APP_TEST_AND_RUN = "app_test_and_run"
|
||||
APP_PREVIEW = "app_preview"
|
||||
APP_CREATE_AND_MANAGEMENT = "app_create_and_management"
|
||||
APP_RELEASE_AND_VERSION = "app_release_and_version"
|
||||
APP_IMPORT_EXPORT_DSL = "app_import_export_dsl"
|
||||
APP_EDIT = "app_edit"
|
||||
APP_MONITOR = "app_monitor"
|
||||
APP_DELETE = "app_delete"
|
||||
APP_ACCESS_CONFIG = "app_access_config"
|
||||
|
||||
DATASET_PREVIEW = "dataset_preview"
|
||||
DATASET_READONLY = "dataset_readonly"
|
||||
DATASET_EDIT = "dataset_edit"
|
||||
DATASET_CREATE_AND_MANAGEMENT = "dataset_create_and_management"
|
||||
DATASET_PIPELINE_TEST = "dataset_pipeline_test"
|
||||
DATASET_DOCUMENT_DOWNLOAD = "dataset_document_download"
|
||||
DATASET_RETRIEVAL_RECALL = "dataset_retrieval_recall"
|
||||
DATASET_USE = "dataset_use"
|
||||
DATASET_DELETE_FILE = "dataset_delete_file"
|
||||
DATASET_PIPELINE_RELEASE = "dataset_pipeline_release"
|
||||
DATASET_DELETE = "dataset_delete"
|
||||
DATASET_ACCESS_CONFIG = "dataset_access_config"
|
||||
DATASET_API_KEY_MANAGE = "dataset_api_key_manage"
|
||||
DATASET_EXTERNAL_CONNECT = "dataset_external_connect"
|
||||
DATASET_IMPORT_EXPORT_DSL = "dataset_import_export_dsl"
|
||||
|
||||
WORKSPACE_MEMBER_MANAGE = "workspace_member_manage"
|
||||
WORKSPACE_ROLE_MANAGE = "workspace_role_manage"
|
||||
API_EXTENSION_MANAGE = "api_extension_manage"
|
||||
CUSTOMIZATION_MANAGE = "customization_manage"
|
||||
|
||||
SNIPPETS_CREATE_AND_MODIFY = "snippets_create_and_modify"
|
||||
SNIPPETS_MANAGE = "snippets_management"
|
||||
@ -61,7 +49,6 @@ class RBACPermission(StrEnum):
|
||||
PLUGIN_DEBUG = "plugin_debug"
|
||||
|
||||
CREDENTIAL_USE = "credential_use"
|
||||
CREDENTIAL_CREATE = "credential_create"
|
||||
CREDENTIAL_MANAGE = "credential_manage"
|
||||
|
||||
TOOL_MANAGE = "tool_manage"
|
||||
|
||||
@ -359,16 +359,15 @@ class ApiTool(Tool):
|
||||
if value is None:
|
||||
return None
|
||||
elif property["type"] == "object" or property["type"] == "array":
|
||||
match value:
|
||||
case str():
|
||||
try:
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
return value
|
||||
case dict():
|
||||
return value
|
||||
case _:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return json.loads(value)
|
||||
except ValueError:
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
raise ValueError(f"Invalid type {property['type']} for property {property}")
|
||||
elif "anyOf" in property and isinstance(property["anyOf"], list):
|
||||
|
||||
@ -12,61 +12,60 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import (
|
||||
FormDisposition,
|
||||
HumanInputSurface,
|
||||
disposition_for_surface,
|
||||
)
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
|
||||
def load_form_dispositions_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, FormDisposition]:
|
||||
"""Resolve each paused form's resume token and approval channels for `surface`."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
if not unique_form_ids:
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_dispositions_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_dispositions_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_dispositions_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> dict[str, FormDisposition]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token or "")
|
||||
)
|
||||
return {
|
||||
form_id: disposition_for_surface(recipients, surface=surface)
|
||||
for form_id, recipients in recipients_by_form_id.items()
|
||||
}
|
||||
|
||||
|
||||
def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Resume tokens only, for callers that don't surface approval channels."""
|
||||
dispositions = load_form_dispositions_by_form_id(form_ids, session=session, surface=surface)
|
||||
return {
|
||||
form_id: disposition.form_token
|
||||
for form_id, disposition in dispositions.items()
|
||||
if disposition.form_token is not None
|
||||
}
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
if not unique_form_ids:
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
|
||||
@ -2,14 +2,14 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, NamedTuple
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType
|
||||
from graphon.nodes.human_input.entities import FormInputConfig, SelectInputConfig
|
||||
from graphon.nodes.human_input.enums import ValueSourceType
|
||||
from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool
|
||||
from graphon.variables import ArrayStringSegment
|
||||
from models.human_input import ApprovalChannel, RecipientType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class HumanInputSurface(StrEnum):
|
||||
@ -20,7 +20,7 @@ class HumanInputSurface(StrEnum):
|
||||
|
||||
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
@ -41,7 +41,7 @@ def is_recipient_type_allowed_for_surface(
|
||||
) -> bool:
|
||||
if recipient_type is None:
|
||||
return False
|
||||
return recipient_type in ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
return recipient_type in _ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
|
||||
|
||||
def get_preferred_form_token(
|
||||
@ -59,39 +59,10 @@ def get_preferred_form_token(
|
||||
return chosen_token
|
||||
|
||||
|
||||
class FormDisposition(NamedTuple):
|
||||
"""How a paused form resolves for one API surface.
|
||||
|
||||
A form's recipients split into those the surface may act on (yielding a resume
|
||||
`form_token`) and those it may not (their channels named in `approval_channels`
|
||||
so the caller is told where approval actually happens instead).
|
||||
"""
|
||||
|
||||
form_token: str | None
|
||||
approval_channels: list[ApprovalChannel]
|
||||
|
||||
|
||||
def disposition_for_surface(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> FormDisposition:
|
||||
if surface is None:
|
||||
return FormDisposition(form_token=get_preferred_form_token(recipients), approval_channels=[])
|
||||
allowed = ALLOWED_RECIPIENT_TYPES_BY_SURFACE[surface]
|
||||
actionable = [(recipient_type, token) for recipient_type, token in recipients if recipient_type in allowed]
|
||||
return FormDisposition(
|
||||
form_token=get_preferred_form_token(actionable),
|
||||
approval_channels=sorted(
|
||||
{recipient_type.approval_channel for recipient_type, _ in recipients if recipient_type not in allowed}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def enrich_human_input_pause_reasons(
|
||||
reasons: Sequence[Mapping[str, Any]],
|
||||
*,
|
||||
dispositions_by_form_id: Mapping[str, FormDisposition],
|
||||
form_tokens_by_form_id: Mapping[str, str],
|
||||
expiration_times_by_form_id: Mapping[str, int],
|
||||
) -> list[dict[str, Any]]:
|
||||
enriched: list[dict[str, Any]] = []
|
||||
@ -100,9 +71,7 @@ def enrich_human_input_pause_reasons(
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
disposition = dispositions_by_form_id.get(form_id)
|
||||
updated["form_token"] = disposition.form_token if disposition else None
|
||||
updated["approval_channels"] = list(disposition.approval_channels) if disposition else []
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
expiration_time = expiration_times_by_form_id.get(form_id)
|
||||
if expiration_time is not None:
|
||||
updated["expiration_time"] = expiration_time
|
||||
|
||||
@ -23,8 +23,6 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from core.workflow.human_input_adapter import adapt_node_config_for_graph
|
||||
from core.workflow.node_runtime import (
|
||||
DIFY_BEFORE_LLM_INVOKE_KEY,
|
||||
BeforeLLMInvoke,
|
||||
DifyFileReferenceFactory,
|
||||
DifyHumanInputNodeRuntime,
|
||||
DifyPreparedLLM,
|
||||
@ -531,19 +529,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(LLMCompatibleNodeData, node_data)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
before_llm_invoke = cast(
|
||||
BeforeLLMInvoke | None,
|
||||
self.graph_init_params.run_context.get(DIFY_BEFORE_LLM_INVOKE_KEY),
|
||||
)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
"model_factory": self._llm_model_factory,
|
||||
"model_instance": (
|
||||
self._wrap_model_instance_for_node(
|
||||
node_data=validated_node_data,
|
||||
model_instance=model_instance,
|
||||
before_invoke=before_llm_invoke,
|
||||
)
|
||||
self._wrap_model_instance_for_node(node_data=validated_node_data, model_instance=model_instance)
|
||||
if wrap_model_instance
|
||||
else model_instance
|
||||
),
|
||||
@ -575,14 +565,13 @@ class DifyNodeFactory(NodeFactory):
|
||||
*,
|
||||
node_data: LLMCompatibleNodeData,
|
||||
model_instance: ModelInstance,
|
||||
before_invoke: BeforeLLMInvoke | None = None,
|
||||
) -> DifyPreparedLLM:
|
||||
# Only graphon's LLM node consumes the polling protocol. Keep classifier
|
||||
# and extractor nodes on the existing wrapper even if the same model
|
||||
# advertises polling support.
|
||||
if node_data.type == BuiltinNodeTypes.LLM and DifyNodeFactory._supports_plugin_llm_polling(model_instance):
|
||||
return DifyPreparedPollingLLM(model_instance, before_invoke=before_invoke)
|
||||
return DifyPreparedLLM(model_instance, before_invoke=before_invoke)
|
||||
return DifyPreparedPollingLLM(model_instance)
|
||||
return DifyPreparedLLM(model_instance)
|
||||
|
||||
@staticmethod
|
||||
def _supports_plugin_llm_polling(model_instance: ModelInstance) -> bool:
|
||||
|
||||
@ -95,8 +95,6 @@ if TYPE_CHECKING:
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
|
||||
|
||||
DIFY_BEFORE_LLM_INVOKE_KEY = "_dify_before_llm_invoke"
|
||||
BeforeLLMInvoke = Callable[[Sequence[PromptMessage]], None]
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
|
||||
|
||||
@ -153,9 +151,8 @@ class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
|
||||
class DifyPreparedLLM(LLMProtocol):
|
||||
"""Workflow-layer adapter that hides the full `ModelInstance` API from `graphon` nodes."""
|
||||
|
||||
def __init__(self, model_instance: ModelInstance, before_invoke: BeforeLLMInvoke | None = None) -> None:
|
||||
def __init__(self, model_instance: ModelInstance) -> None:
|
||||
self._model_instance = model_instance
|
||||
self._before_invoke = before_invoke
|
||||
|
||||
@property
|
||||
@override
|
||||
@ -196,10 +193,6 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
|
||||
return self._model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
def _run_before_invoke(self, prompt_messages: Sequence[PromptMessage]) -> None:
|
||||
if self._before_invoke is not None:
|
||||
self._before_invoke(prompt_messages)
|
||||
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
@ -232,7 +225,6 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> LLMResult | Generator[LLMResultChunk, None, None]:
|
||||
self._run_before_invoke(prompt_messages)
|
||||
return self._model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=dict(model_parameters),
|
||||
@ -273,7 +265,6 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
stop: Sequence[str] | None,
|
||||
stream: bool,
|
||||
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
|
||||
self._run_before_invoke(prompt_messages)
|
||||
return invoke_llm_with_structured_output(
|
||||
provider=self.provider,
|
||||
model_schema=self.get_model_schema(),
|
||||
@ -293,10 +284,10 @@ class DifyPreparedLLM(LLMProtocol):
|
||||
class DifyPreparedPollingLLM(DifyPreparedLLM, LLMPollingCapableProtocol):
|
||||
"""Prepared workflow LLM adapter that exposes Graphon's polling protocol."""
|
||||
|
||||
def __init__(self, model_instance: ModelInstance, before_invoke: BeforeLLMInvoke | None = None) -> None:
|
||||
def __init__(self, model_instance: ModelInstance) -> None:
|
||||
from core.plugin.impl.model_runtime import PluginModelRuntime
|
||||
|
||||
super().__init__(model_instance, before_invoke=before_invoke)
|
||||
super().__init__(model_instance)
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
if not isinstance(model_type_instance, LargeLanguageModel):
|
||||
raise TypeError("Polling wrapper requires a large-language-model instance.")
|
||||
@ -317,7 +308,6 @@ class DifyPreparedPollingLLM(DifyPreparedLLM, LLMPollingCapableProtocol):
|
||||
stop: Sequence[str] | None,
|
||||
json_schema: Mapping[str, Any] | None,
|
||||
) -> LLMPollingResult:
|
||||
self._run_before_invoke(prompt_messages)
|
||||
return self._plugin_model_runtime.start_llm_polling(
|
||||
provider=self.provider,
|
||||
model=self.model_name,
|
||||
|
||||
@ -5,7 +5,6 @@ def init_app(app: DifyApp):
|
||||
from commands import (
|
||||
add_qdrant_index,
|
||||
archive_workflow_runs,
|
||||
archive_workflow_runs_plan,
|
||||
backfill_plugin_auto_upgrade,
|
||||
clean_expired_messages,
|
||||
clean_workflow_runs,
|
||||
@ -73,7 +72,6 @@ def init_app(app: DifyApp):
|
||||
setup_datasource_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
install_rag_pipeline_plugins,
|
||||
archive_workflow_runs_plan,
|
||||
archive_workflow_runs,
|
||||
delete_archived_workflow_runs,
|
||||
restore_workflow_runs,
|
||||
|
||||
@ -25,7 +25,7 @@ from extensions.redis_names import (
|
||||
serialize_redis_name_args,
|
||||
)
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.pubsub_channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel
|
||||
|
||||
@ -457,14 +457,16 @@ def init_app(app: DifyApp):
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
join_timeout_ms=join_timeout_ms,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
|
||||
|
||||
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||
|
||||
@ -291,11 +291,6 @@ class AgentConfigSnapshotListResponse(ResponseModel):
|
||||
data: list[AgentConfigSnapshotSummaryResponse]
|
||||
|
||||
|
||||
class AgentConfigSnapshotRestoreResponse(ResponseModel):
|
||||
result: Literal["success"]
|
||||
active_config_snapshot_id: str
|
||||
|
||||
|
||||
class AgentComposerAgentResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from .pubsub_channel import BroadcastChannel
|
||||
from .channel import BroadcastChannel
|
||||
from .sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
__all__ = ["BroadcastChannel", "ShardedRedisBroadcastChannel"]
|
||||
|
||||
@ -7,7 +7,6 @@ from typing import Any, Self, override
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.client import PubSub
|
||||
|
||||
@ -27,6 +26,8 @@ class RedisSubscriptionBase(Subscription):
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
@ -38,6 +39,11 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._listener_thread: threading.Thread | None = None
|
||||
self._start_lock = threading.Lock()
|
||||
self._started = False
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency. The listener is a daemon
|
||||
# and exits on its own within one poll window (~1s), so a low value
|
||||
# here just means close() returns sooner without breaking anything.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def _start_if_needed(self) -> None:
|
||||
"""Start the subscription if not already started."""
|
||||
@ -84,11 +90,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
if raw_message is None:
|
||||
continue
|
||||
|
||||
# If close() sent a control event to unblock us, exit immediately
|
||||
# without processing any message — the subscription is shutting down.
|
||||
if self._closed.is_set():
|
||||
break
|
||||
|
||||
if raw_message.get("type") != self._get_message_type():
|
||||
continue
|
||||
|
||||
@ -118,8 +119,6 @@ class RedisSubscriptionBase(Subscription):
|
||||
continue
|
||||
|
||||
self._enqueue_message(payload_bytes)
|
||||
if payload_bytes == SIG_CLOSE:
|
||||
break
|
||||
|
||||
_logger.debug("%s listener thread stopped for channel %s", self._get_subscription_type().title(), self._topic)
|
||||
try:
|
||||
@ -165,20 +164,14 @@ class RedisSubscriptionBase(Subscription):
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
yield item
|
||||
|
||||
@override
|
||||
def __iter__(self) -> Iterator[bytes]:
|
||||
"""Return an iterator over messages from the subscription."""
|
||||
if self._closed.is_set():
|
||||
return iter(())
|
||||
try:
|
||||
self._start_if_needed()
|
||||
except SubscriptionClosedError:
|
||||
return iter(())
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
@override
|
||||
@ -215,55 +208,24 @@ class RedisSubscriptionBase(Subscription):
|
||||
@override
|
||||
def close(self) -> None:
|
||||
"""Close the subscription and clean up resources."""
|
||||
with self._start_lock:
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
listener = self._listener_thread
|
||||
self._listener_thread = None
|
||||
started = self._started
|
||||
|
||||
if started:
|
||||
self._unblock_message_iterator()
|
||||
|
||||
# Send a control event on the same Redis channel to unblock the
|
||||
self._publish_close_event()
|
||||
if self._closed.is_set():
|
||||
return
|
||||
|
||||
self._closed.set()
|
||||
# NOTE: PubSub is not thread-safe. More specifically, the `PubSub.close` method and the
|
||||
# message retrieval method should NOT be called concurrently.
|
||||
#
|
||||
# Due to the restriction above, the PubSub cleanup logic happens inside the consumer thread.
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=2)
|
||||
|
||||
def _unblock_message_iterator(self) -> None:
|
||||
try:
|
||||
self._queue.put_nowait(SIG_CLOSE)
|
||||
except queue.Full:
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
try:
|
||||
self._queue.put_nowait(SIG_CLOSE)
|
||||
except queue.Full:
|
||||
pass
|
||||
listener = self._listener_thread
|
||||
if listener is not None:
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
self._listener_thread = None
|
||||
|
||||
# Abstract methods to be implemented by subclasses
|
||||
def _get_subscription_type(self) -> str:
|
||||
"""Return the subscription type (e.g., 'regular' or 'sharded')."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish a control event on the Redis channel to unblock the listener.
|
||||
|
||||
This is called by close() after setting _closed. The subclass should
|
||||
publish an empty message on the same topic so that a blocking
|
||||
get_message() call in the listener thread returns promptly.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _subscribe(self) -> None:
|
||||
"""Subscribe to the Redis topic using the appropriate command."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,17 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BroadcastChannel:
|
||||
"""
|
||||
@ -26,11 +22,16 @@ class BroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
# See `RedisSubscriptionBase._join_timeout_ms`: how long close()
|
||||
# waits for the listener thread before returning.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> Topic:
|
||||
return Topic(self._client, topic)
|
||||
return Topic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class Topic:
|
||||
@ -38,10 +39,13 @@ class Topic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -57,6 +61,7 @@ class Topic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -67,13 +72,6 @@ class _RedisSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "regular"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.publish(self._topic, SIG_CLOSE)
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
@ -1,17 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ShardedRedisBroadcastChannel:
|
||||
"""
|
||||
@ -24,11 +20,14 @@ class ShardedRedisBroadcastChannel:
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> ShardedTopic:
|
||||
return ShardedTopic(self._client, topic)
|
||||
return ShardedTopic(self._client, topic, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
@ -36,10 +35,13 @@ class ShardedTopic:
|
||||
self,
|
||||
redis_client: Redis | RedisCluster,
|
||||
topic: str,
|
||||
*,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._redis_topic = serialize_redis_name(topic)
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
return self
|
||||
@ -55,6 +57,7 @@ class ShardedTopic:
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._redis_topic,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -65,13 +68,6 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
def _get_subscription_type(self) -> str:
|
||||
return "sharded"
|
||||
|
||||
@override
|
||||
def _publish_close_event(self) -> None:
|
||||
try:
|
||||
self._client.spublish(self._topic, SIG_CLOSE) # type: ignore[attr-defined,union-attr]
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def _subscribe(self) -> None:
|
||||
assert self._pubsub is not None
|
||||
|
||||
@ -9,7 +9,6 @@ from typing import Self, override
|
||||
from extensions.redis_names import serialize_redis_name
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from libs.broadcast_channel.signals import SIG_CLOSE
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -30,15 +29,20 @@ class StreamsBroadcastChannel:
|
||||
redis_client: Redis | RedisCluster,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._retention_seconds = max(int(retention_seconds or 0), 0)
|
||||
# Max time close() will wait for the listener thread to finish.
|
||||
# See `_StreamsSubscription._join_timeout_ms` for the rationale.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
def topic(self, topic: str) -> StreamsTopic:
|
||||
return StreamsTopic(
|
||||
self._client,
|
||||
topic,
|
||||
retention_seconds=self._retention_seconds,
|
||||
join_timeout_ms=self._join_timeout_ms,
|
||||
)
|
||||
|
||||
|
||||
@ -49,11 +53,13 @@ class StreamsTopic:
|
||||
topic: str,
|
||||
*,
|
||||
retention_seconds: int = 600,
|
||||
join_timeout_ms: int = 2000,
|
||||
):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
self._key = serialize_redis_name(f"stream:{topic}")
|
||||
self._retention_seconds = retention_seconds
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
self.max_length = 5000
|
||||
|
||||
def as_producer(self) -> Producer:
|
||||
@ -71,15 +77,23 @@ class StreamsTopic:
|
||||
return self
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _StreamsSubscription(self._client, self._key)
|
||||
return _StreamsSubscription(self._client, self._key, join_timeout_ms=self._join_timeout_ms)
|
||||
|
||||
|
||||
class _StreamsSubscription(Subscription):
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(self, client: Redis | RedisCluster, key: str):
|
||||
def __init__(self, client: Redis | RedisCluster, key: str, *, join_timeout_ms: int = 2000):
|
||||
self._client = client
|
||||
self._key = key
|
||||
# Max time close() will wait for the listener thread to finish before
|
||||
# returning. Bounds SSE close tail latency: the listener blocks on
|
||||
# XREAD with BLOCK=1000ms, so close() naturally waits up to ~1s for
|
||||
# the thread to notice _closed. Setting this lower lets close()
|
||||
# return promptly while the daemon listener exits on its own within
|
||||
# one BLOCK window - safe because the listener holds no critical
|
||||
# state. ``0`` means close() does not wait at all.
|
||||
self._join_timeout_ms = max(int(join_timeout_ms or 0), 0)
|
||||
|
||||
self._queue: queue.Queue[object] = queue.Queue()
|
||||
|
||||
@ -92,6 +106,7 @@ class _StreamsSubscription(Subscription):
|
||||
# reading and writing the _listener / `_closed` attribute.
|
||||
self._lock = threading.Lock()
|
||||
self._closed: bool = False
|
||||
# self._closed = threading.Event()
|
||||
self._listener: threading.Thread | None = None
|
||||
|
||||
def _listen(self) -> None:
|
||||
@ -129,8 +144,6 @@ class _StreamsSubscription(Subscription):
|
||||
case bytes() | bytearray():
|
||||
data_bytes = bytes(data)
|
||||
if data_bytes is not None:
|
||||
if data_bytes == SIG_CLOSE:
|
||||
break
|
||||
self._queue.put_nowait(data_bytes)
|
||||
last_id = entry_id
|
||||
finally:
|
||||
@ -190,13 +203,6 @@ class _StreamsSubscription(Subscription):
|
||||
assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue"
|
||||
return bytes(item)
|
||||
|
||||
def _publish_close_event(self) -> None:
|
||||
"""Publish an empty message to the stream to unblock the listener's xread."""
|
||||
try:
|
||||
self._client.xadd(self._key, {b"data": SIG_CLOSE})
|
||||
except Exception:
|
||||
logger.exception("failed to publish close event")
|
||||
|
||||
@override
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
@ -206,17 +212,16 @@ class _StreamsSubscription(Subscription):
|
||||
listener = self._listener
|
||||
if listener is not None:
|
||||
self._listener = None
|
||||
|
||||
if listener is not None:
|
||||
self._publish_close_event()
|
||||
|
||||
# We close the listener outside of the with block to avoid holding the
|
||||
# lock for a long time.
|
||||
if listener is not None and listener.is_alive():
|
||||
listener.join(timeout=2)
|
||||
listener.join(timeout=self._join_timeout_ms / 1000.0)
|
||||
if listener.is_alive():
|
||||
logger.debug(
|
||||
"Streams subscription listener for key %s did not stop after join; "
|
||||
"Streams subscription listener for key %s did not stop within %dms; "
|
||||
"daemon thread will exit on its own within one poll window.",
|
||||
self._key,
|
||||
self._join_timeout_ms,
|
||||
)
|
||||
|
||||
# Context manager helpers
|
||||
|
||||
@ -1 +0,0 @@
|
||||
SIG_CLOSE = b"__closed__"
|
||||
@ -1,39 +0,0 @@
|
||||
"""agent drive skill metadata refactor
|
||||
|
||||
Revision ID: b2515f9d4c2a
|
||||
Revises: 4f7b2c8d9a10
|
||||
Create Date: 2026-06-18 23:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import mysql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b2515f9d4c2a"
|
||||
down_revision = "4f7b2c8d9a10"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_drive_files",
|
||||
sa.Column("is_skill", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_drive_files",
|
||||
sa.Column("skill_metadata", sa.Text().with_variant(mysql.LONGTEXT(), "mysql"), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"agent_drive_files_tenant_agent_is_skill_key_idx",
|
||||
"agent_drive_files",
|
||||
["tenant_id", "agent_id", "is_skill", "key"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("agent_drive_files_tenant_agent_is_skill_key_idx", table_name="agent_drive_files")
|
||||
op.drop_column("agent_drive_files", "skill_metadata")
|
||||
op.drop_column("agent_drive_files", "is_skill")
|
||||
@ -1,66 +0,0 @@
|
||||
"""add agent debug conversations
|
||||
|
||||
Revision ID: c8f4a6b2d3e1
|
||||
Revises: b2515f9d4c2a
|
||||
Create Date: 2026-06-22 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c8f4a6b2d3e1"
|
||||
down_revision = "b2515f9d4c2a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _is_pg(conn) -> bool:
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
|
||||
def _uuid_column(name: str, *, nullable: bool = False, primary_key: bool = False) -> sa.Column:
|
||||
kwargs = {"nullable": nullable, "primary_key": primary_key}
|
||||
if primary_key and _is_pg(op.get_bind()):
|
||||
kwargs["server_default"] = sa.text("uuidv7()")
|
||||
return sa.Column(name, models.types.StringUUID(), **kwargs)
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"agent_debug_conversations",
|
||||
_uuid_column("id", primary_key=True),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("agent_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("account_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("conversation_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("agent_debug_conversation_pkey")),
|
||||
sa.UniqueConstraint(
|
||||
"tenant_id",
|
||||
"agent_id",
|
||||
"account_id",
|
||||
name=op.f("agent_debug_conversation_agent_account_unique"),
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"agent_debug_conversation_conversation_idx",
|
||||
"agent_debug_conversations",
|
||||
["conversation_id"],
|
||||
)
|
||||
op.create_index(
|
||||
"agent_debug_conversation_account_idx",
|
||||
"agent_debug_conversations",
|
||||
["tenant_id", "account_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("agent_debug_conversation_account_idx", table_name="agent_debug_conversations")
|
||||
op.drop_index("agent_debug_conversation_conversation_idx", table_name="agent_debug_conversations")
|
||||
op.drop_table("agent_debug_conversations")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user