mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 17:37:39 +08:00
Compare commits
185 Commits
feat(agent
...
feat/cli-e
| Author | SHA1 | Date | |
|---|---|---|---|
| 11e464ddae | |||
| 971f7b964b | |||
| 33af0f55c2 | |||
| 2f60dd6ca5 | |||
| ada1da1781 | |||
| 777ba22431 | |||
| 0d66bbefc3 | |||
| 7306fa4c50 | |||
| fe91ccfe5d | |||
| 2d22d87970 | |||
| 56ba8f421a | |||
| 3133b196ad | |||
| 3cc20de830 | |||
| c8abb11bf0 | |||
| f9320b2c91 | |||
| f0fd7ddb60 | |||
| b77f5f1e4a | |||
| b67c3a5f76 | |||
| 5b5a06136a | |||
| 6e3c9597ff | |||
| 3c98f96ae8 | |||
| 44725dde74 | |||
| d3058d63bd | |||
| 4fc62d3b38 | |||
| e14cb209a4 | |||
| bb3c9929f9 | |||
| 35a55813d2 | |||
| a247d625e5 | |||
| 5c7f05bd10 | |||
| 02e1a60cde | |||
| 57b573d02b | |||
| 9de40e8f21 | |||
| cad0942f4d | |||
| cb9b1b593e | |||
| 2a8bdc2373 | |||
| ee6a07d13c | |||
| 2d6c9300e3 | |||
| d6b4c800c2 | |||
| 1b37635f92 | |||
| 86af36429d | |||
| b96ea94505 | |||
| d649cccda0 | |||
| 5cbbd78f38 | |||
| 5a0ad4ecd9 | |||
| 1e76b9e1b8 | |||
| 1b972c4e09 | |||
| 7968d2c3c8 | |||
| 7507e9ba67 | |||
| ca31762e26 | |||
| f591da7865 | |||
| f19679b217 | |||
| b682591c7a | |||
| 8f6b59feff | |||
| 99833f65d8 | |||
| 696fc5c213 | |||
| eae44cfecb | |||
| dea4e66456 | |||
| 3cd0da303a | |||
| 888483a2f8 | |||
| 7056985f72 | |||
| 6ce61eae59 | |||
| 079af312c6 | |||
| 0da13dfe4d | |||
| 1ff4d75084 | |||
| e35d23c3cb | |||
| e530e84772 | |||
| 2257a4f1ef | |||
| f465dc5090 | |||
| 5c1cfe6ada | |||
| 8d401d84c7 | |||
| 363aabee73 | |||
| b74287c2ab | |||
| e61073ccd5 | |||
| c64d3e98c4 | |||
| 748d790a0d | |||
| 0f52c5e6f3 | |||
| a3265f722e | |||
| 5658065b97 | |||
| 8fc2807194 | |||
| d9b928577c | |||
| fc7716704d | |||
| 71ffaacb58 | |||
| cfc1cf2b8c | |||
| 055d9b9f0a | |||
| 21711bebeb | |||
| 400befc451 | |||
| 4649e52384 | |||
| becccbf288 | |||
| c045e0b635 | |||
| cf7859cbf9 | |||
| 81d2c1638f | |||
| 69923a16e1 | |||
| 7114415cfd | |||
| 6c8ec0b1c8 | |||
| 86497045c9 | |||
| 687a177b24 | |||
| 4a6d278354 | |||
| 7d69302e9f | |||
| bcd573e560 | |||
| 07c0c4e7b1 | |||
| a8a2ca7b98 | |||
| de47d43b65 | |||
| 240912cef5 | |||
| 72e040ead3 | |||
| c0ee821d45 | |||
| c7c3296572 | |||
| e7be04fd58 | |||
| df6b5be50a | |||
| 8e5f09091b | |||
| 0a3005701f | |||
| d8571ce965 | |||
| f241ae25be | |||
| c6474a2a8b | |||
| 480d05bc48 | |||
| f75725ccd9 | |||
| 2fe8c48255 | |||
| ec5404cc9d | |||
| 20f62b9919 | |||
| 04f5555580 | |||
| 129af96c23 | |||
| df40960f5d | |||
| 599960024d | |||
| 6805d9bfc0 | |||
| 928f888ef5 | |||
| f46c03460e | |||
| 0b60338ad5 | |||
| 91ac465982 | |||
| 9490d63c50 | |||
| ae538ced47 | |||
| 487249728b | |||
| 372a2e3e9c | |||
| 4939a9c33d | |||
| b6f92f1dc4 | |||
| ce276573a8 | |||
| 5070cc9668 | |||
| a392a72960 | |||
| 30270b5c30 | |||
| 24715a9570 | |||
| c530a5d272 | |||
| 418ee7398e | |||
| 78f40c0d25 | |||
| 2cc567c6a3 | |||
| a180ab19e4 | |||
| 13eaa436e7 | |||
| 5ff98b97df | |||
| 982ada6f4e | |||
| e0d5bc48d9 | |||
| 3596d12e4c | |||
| e8de10a3b5 | |||
| f5ab5e7eb3 | |||
| 0c40e1c2a0 | |||
| c29d76757e | |||
| 91c1d3ad81 | |||
| 57b02e341c | |||
| b94ff65e9f | |||
| 678260e34e | |||
| 739e34d08a | |||
| 825fb9cb89 | |||
| 0e1f19a380 | |||
| 332d1ea533 | |||
| 9cdeffd0b1 | |||
| 09ef785a20 | |||
| d2788d7aba | |||
| cee90a4e82 | |||
| b2710b875b | |||
| e0e0ae372a | |||
| bc3b1c0c81 | |||
| 6464255d33 | |||
| 50face5760 | |||
| b034449a0c | |||
| a8d380bcaf | |||
| bee21c9f86 | |||
| cab215e209 | |||
| 7ae4ca9a60 | |||
| d342ff1a1e | |||
| 4384d8910e | |||
| b734afd609 | |||
| fc773b9f57 | |||
| 6e1e0d9439 | |||
| 5c5a6e83e5 | |||
| dade318f00 | |||
| 5646bda88e | |||
| ebff9a3639 | |||
| 58b8fc21d4 | |||
| e0ad088657 |
@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-query-mutation
|
||||
1
.claude/skills/how-to-write-component
Symbolic link
1
.claude/skills/how-to-write-component
Symbolic link
@ -0,0 +1 @@
|
||||
../../.agents/skills/how-to-write-component
|
||||
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -166,6 +166,7 @@
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
|
||||
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@ -51,6 +51,15 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
- name: Check dify-agent inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: dify-agent-changes
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
dify-agent/**/*.py
|
||||
dify-agent/pyproject.toml
|
||||
dify-agent/uv.lock
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
@ -76,6 +85,17 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- if: github.event_name != 'merge_group' && steps.dify-agent-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
cd dify-agent
|
||||
uv sync --dev
|
||||
# fmt first to avoid line too long
|
||||
uv run ruff format .
|
||||
# Fix lint errors
|
||||
uv run ruff check --fix .
|
||||
# Format code
|
||||
uv run ruff format .
|
||||
|
||||
- name: count migration progress
|
||||
if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true'
|
||||
run: |
|
||||
|
||||
415
.github/workflows/cli-e2e.yml
vendored
Normal file
415
.github/workflows/cli-e2e.yml
vendored
Normal file
@ -0,0 +1,415 @@
|
||||
name: CLI E2E Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
cli_ref:
|
||||
description: "Git ref (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
edition:
|
||||
description: "Dify edition"
|
||||
type: choice
|
||||
required: false
|
||||
default: ee
|
||||
options: [ee, ce]
|
||||
|
||||
test_scope:
|
||||
description: "smoke = [P0] only / full = all cases"
|
||||
type: choice
|
||||
required: false
|
||||
default: full
|
||||
options: [smoke, full]
|
||||
|
||||
# ── Suite on/off ────────────────────────────────────────────────────────
|
||||
suite_framework_output_error:
|
||||
description: "framework + output + error-handling suites"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_discovery:
|
||||
description: "discovery suite (get app / describe app)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_run:
|
||||
description: "run suite (basic / streaming / conversation / file / hitl)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_auth:
|
||||
description: "auth suite (login / status / whoami / use / devices / logout)"
|
||||
type: boolean
|
||||
default: true
|
||||
suite_agent:
|
||||
description: "agent suite"
|
||||
type: boolean
|
||||
default: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
# ── Shared env injected into every E2E job ───────────────────────────────────
|
||||
# Each job reads DIFY_E2E_TOKEN + app IDs from the provision job outputs,
|
||||
# so global-setup skips minting and finds existing apps in < 10 s.
|
||||
env:
|
||||
DIFY_E2E_NO_KEYRING: "1" # Linux CI has no keychain; skip probe
|
||||
VITEST_RETRY: "2" # Retry flaky staging responses
|
||||
|
||||
jobs:
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 0. PROVISION — mint token + import DSL fixtures (runs once, outputs IDs)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
provision:
|
||||
name: "Provision: mint token + DSL apps"
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
outputs:
|
||||
token: ${{ steps.out.outputs.DIFY_E2E_TOKEN }}
|
||||
workspace_id: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_ID }}
|
||||
workspace_name: ${{ steps.out.outputs.DIFY_E2E_WORKSPACE_NAME }}
|
||||
ws2_id: ${{ steps.out.outputs.DIFY_E2E_WS2_ID }}
|
||||
chat_app_id: ${{ steps.out.outputs.DIFY_E2E_CHAT_APP_ID }}
|
||||
workflow_app_id: ${{ steps.out.outputs.DIFY_E2E_WORKFLOW_APP_ID }}
|
||||
file_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_APP_ID }}
|
||||
file_chat_app_id: ${{ steps.out.outputs.DIFY_E2E_FILE_CHAT_APP_ID }}
|
||||
hitl_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_APP_ID }}
|
||||
hitl_external_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_EXTERNAL_APP_ID }}
|
||||
hitl_single_action_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_SINGLE_ACTION_APP_ID }}
|
||||
hitl_multi_node_app_id: ${{ steps.out.outputs.DIFY_E2E_HITL_MULTI_NODE_APP_ID }}
|
||||
ws2_app_id: ${{ steps.out.outputs.DIFY_E2E_WS2_APP_ID }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with:
|
||||
package_json_field: packageManager
|
||||
run_install: false
|
||||
|
||||
- name: Install CLI dependencies
|
||||
working-directory: cli
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Mint token & provision apps
|
||||
id: out
|
||||
working-directory: cli
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_TOKEN: ${{ secrets.DIFY_E2E_TOKEN }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
run: bun scripts/e2e-provision.ts
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-B. framework + output + error-handling (parallel with run/discovery)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-framework-output-error:
|
||||
name: "Suite: framework + output + error-handling"
|
||||
if: ${{ inputs.suite_framework_output_error == true }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run framework + output + error-handling
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/framework/**/*.e2e.ts,test/e2e/suites/output/**/*.e2e.ts,test/e2e/suites/error-handling/**/*.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-C. Discovery (parallel)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-discovery:
|
||||
name: "Suite: discovery"
|
||||
if: ${{ inputs.suite_discovery == true }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run discovery suite
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/discovery/**/*.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-D. Run suite — 5 files in matrix (parallel)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-run:
|
||||
name: "Suite: run / ${{ matrix.name }}"
|
||||
if: ${{ inputs.suite_run == true }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 20
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- name: basic
|
||||
file: run-app-basic.e2e.ts
|
||||
- name: streaming
|
||||
file: run-app-streaming.e2e.ts
|
||||
- name: conversation
|
||||
file: run-app-conversation.e2e.ts
|
||||
- name: file
|
||||
file: run-app-file.e2e.ts
|
||||
- name: hitl
|
||||
file: run-app-hitl.e2e.ts
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: "Run run/${{ matrix.name }}"
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_SSO_TOKEN: ${{ secrets.DIFY_E2E_SSO_TOKEN }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_FILE_APP_ID: ${{ needs.provision.outputs.file_app_id }}
|
||||
DIFY_E2E_FILE_CHAT_APP_ID: ${{ needs.provision.outputs.file_chat_app_id }}
|
||||
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
|
||||
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
|
||||
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
|
||||
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/run/${{ matrix.file }}"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
- name: Upload results on failure
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e-run-${{ matrix.name }}-${{ github.run_id }}
|
||||
path: cli/test-results/
|
||||
retention-days: 3
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 1-E. auth/login + status + whoami (parallel, read-only, safe)
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-auth-safe:
|
||||
name: "Suite: auth (login / status / whoami)"
|
||||
if: ${{ inputs.suite_auth == true }}
|
||||
needs: provision
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run auth/login + status + whoami
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_INCLUDE: "test/e2e/suites/auth/login.e2e.ts,test/e2e/suites/auth/status.e2e.ts,test/e2e/suites/auth/whoami.e2e.ts"
|
||||
run: |
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
pnpm test:e2e
|
||||
fi
|
||||
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
# 2. DESTRUCTIVE — auth/use + devices + logout + agent (serial, runs LAST)
|
||||
# Must wait for ALL parallel suites to finish to avoid token revocation
|
||||
# invalidating other in-flight requests.
|
||||
# ════════════════════════════════════════════════════════════════════════════
|
||||
suite-last:
|
||||
name: "Suite: auth-use + devices + logout + agent (last, serial)"
|
||||
# Runs when auth is selected; also runs after all parallel jobs finish
|
||||
if: ${{ inputs.suite_auth == true || inputs.suite_agent == true }}
|
||||
needs:
|
||||
- provision
|
||||
- suite-framework-output-error
|
||||
- suite-discovery
|
||||
- suite-run
|
||||
- suite-auth-safe
|
||||
# `needs` on a skipped job is treated as success — safe to proceed even if
|
||||
# some suites were disabled via toggle.
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 25
|
||||
defaults:
|
||||
run:
|
||||
working-directory: cli
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- uses: ./.github/actions/setup-web
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with: { bun-version: latest }
|
||||
- uses: pnpm/action-setup@b906affcce14559ad1aafd4ab0e942779e9f58b1 # v4
|
||||
with: { package_json_field: packageManager, run_install: false }
|
||||
- run: pnpm install --frozen-lockfile
|
||||
- run: pnpm tree:gen
|
||||
|
||||
- name: Run use / devices / logout / agent (serial)
|
||||
env:
|
||||
DIFY_E2E_HOST: ${{ secrets.DIFY_E2E_HOST }}
|
||||
DIFY_E2E_EMAIL: ${{ secrets.DIFY_E2E_EMAIL }}
|
||||
DIFY_E2E_PASSWORD: ${{ secrets.DIFY_E2E_PASSWORD }}
|
||||
DIFY_E2E_EDITION: ${{ inputs.edition || 'ee' }}
|
||||
DIFY_E2E_TOKEN: ${{ needs.provision.outputs.token }}
|
||||
DIFY_E2E_WORKSPACE_ID: ${{ needs.provision.outputs.workspace_id }}
|
||||
DIFY_E2E_WORKSPACE_NAME: ${{ needs.provision.outputs.workspace_name }}
|
||||
DIFY_E2E_WS2_ID: ${{ needs.provision.outputs.ws2_id }}
|
||||
DIFY_E2E_CHAT_APP_ID: ${{ needs.provision.outputs.chat_app_id }}
|
||||
DIFY_E2E_WORKFLOW_APP_ID: ${{ needs.provision.outputs.workflow_app_id }}
|
||||
DIFY_E2E_HITL_APP_ID: ${{ needs.provision.outputs.hitl_app_id }}
|
||||
DIFY_E2E_HITL_EXTERNAL_APP_ID: ${{ needs.provision.outputs.hitl_external_app_id }}
|
||||
DIFY_E2E_HITL_SINGLE_ACTION_APP_ID: ${{ needs.provision.outputs.hitl_single_action_app_id }}
|
||||
DIFY_E2E_HITL_MULTI_NODE_APP_ID: ${{ needs.provision.outputs.hitl_multi_node_app_id }}
|
||||
run: |
|
||||
# Collect files in safe order: use → devices → logout (revokes last) → agent
|
||||
FILES=()
|
||||
if [ "${{ inputs.suite_auth }}" = "true" ]; then
|
||||
FILES+=(
|
||||
test/e2e/suites/auth/use.e2e.ts
|
||||
test/e2e/suites/auth/devices.e2e.ts
|
||||
test/e2e/suites/auth/logout.e2e.ts
|
||||
)
|
||||
fi
|
||||
if [ "${{ inputs.suite_agent }}" = "true" ]; then
|
||||
while IFS= read -r f; do FILES+=("$f"); done \
|
||||
< <(find test/e2e/suites/agent -name '*.e2e.ts' | sort)
|
||||
fi
|
||||
|
||||
[ ${#FILES[@]} -eq 0 ] && { echo "Nothing to run."; exit 0; }
|
||||
|
||||
# Pass files via DIFY_E2E_INCLUDE (comma-separated) so vitest
|
||||
# config's include list is overridden instead of ANDed.
|
||||
INCLUDE=$(IFS=,; echo "${FILES[*]}")
|
||||
if [ "${{ inputs.test_scope }}" = "smoke" ]; then
|
||||
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e -- -t "\[P0\]"
|
||||
else
|
||||
DIFY_E2E_INCLUDE="$INCLUDE" pnpm test:e2e
|
||||
fi
|
||||
|
||||
- name: Upload results on failure
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: e2e-last-${{ github.run_id }}
|
||||
path: cli/test-results/
|
||||
retention-days: 3
|
||||
10
.github/workflows/cli-tests.yml
vendored
10
.github/workflows/cli-tests.yml
vendored
@ -15,8 +15,12 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
name: CLI Tests (${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [depot-ubuntu-24.04, windows-latest, macos-latest]
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
@ -37,7 +41,7 @@ jobs:
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: Deploy Agent Dev
|
||||
name: Deploy SaaS
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/agent-dev"
|
||||
- "deploy/saas"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,13 +16,13 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/saas'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
host: ${{ secrets.SAAS_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
|
||||
${{ vars.SSH_SCRIPT_SAAS_DEV || secrets.SSH_SCRIPT_SAAS_DEV }}
|
||||
63
.github/workflows/style.yml
vendored
63
.github/workflows/style.yml
vendored
@ -95,6 +95,51 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: vp run knip
|
||||
|
||||
ts-common-style:
|
||||
name: TS Common
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
checks: write
|
||||
pull-requests: read
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
cli/**
|
||||
e2e/**
|
||||
sdks/nodejs-client/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
eslint.config.mjs
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Restore ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
id: eslint-cache-restore
|
||||
@ -105,28 +150,14 @@ jobs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
- name: Style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run lint:ci
|
||||
|
||||
- name: Web tsslint
|
||||
- name: Type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: .
|
||||
run: vp run type-check
|
||||
|
||||
- name: Web dead code check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run knip
|
||||
|
||||
- name: Save ESLint cache
|
||||
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
|
||||
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -259,3 +259,6 @@ scripts/stress-test/reports/
|
||||
.qoder/*
|
||||
.context/
|
||||
.eslintcache
|
||||
|
||||
# Vitest local reports
|
||||
web/.vitest-reports/
|
||||
|
||||
27
SECURITY.md
Normal file
27
SECURITY.md
Normal file
@ -0,0 +1,27 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you believe you have found a security vulnerability in Dify, please report it privately through GitHub Security Advisories:
|
||||
|
||||
https://github.com/langgenius/dify/security/advisories/new
|
||||
|
||||
Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.
|
||||
|
||||
When submitting a report, include as much relevant information as you can safely provide, such as:
|
||||
|
||||
- A description of the vulnerability
|
||||
- Steps to reproduce, if safe to share privately
|
||||
- Affected components, versions, or configurations
|
||||
- Potential impact
|
||||
- Any suggested mitigation or fix, if available
|
||||
|
||||
The maintainers will review reports submitted through GitHub Security Advisories and coordinate follow-up there.
|
||||
|
||||
## Public Disclosure
|
||||
|
||||
Please avoid publicly disclosing details of a vulnerability until it has been reviewed and, where appropriate, a fix or mitigation has been made available.
|
||||
|
||||
## Security Updates
|
||||
|
||||
Security fixes may be released through normal project releases or other appropriate channels. Users are encouraged to keep Dify deployments up to date.
|
||||
@ -17,7 +17,7 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
git g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -223,10 +223,11 @@ def initialize_extensions(app: DifyApp):
|
||||
|
||||
def create_migrations_app() -> DifyApp:
|
||||
app = create_flask_app_with_configs()
|
||||
from extensions import ext_database, ext_migrate
|
||||
from extensions import ext_commands, ext_database, ext_migrate
|
||||
|
||||
# Initialize only required extensions
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init_app(app)
|
||||
ext_commands.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
@ -31,20 +31,26 @@ from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAge
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendAgentAppRunInput,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendAgentAppRunInput",
|
||||
"AgentBackendError",
|
||||
"AgentBackendHTTPError",
|
||||
"AgentBackendInternalEvent",
|
||||
@ -66,9 +72,11 @@ __all__ = [
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"CleanupLayerSpec",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"extract_cleanup_layer_specs",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
|
||||
@ -20,6 +20,8 @@ from dify_agent.protocol import (
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunPausedEvent,
|
||||
RunPausedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
@ -34,6 +36,7 @@ class FakeAgentBackendScenario(StrEnum):
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
@ -89,6 +92,13 @@ class FakeAgentBackendRunClient:
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="paused",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
@ -115,3 +125,17 @@ class FakeAgentBackendRunClient:
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunPausedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunPausedEventData(
|
||||
reason="human_input_required",
|
||||
message="Agent requested human input.",
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -11,22 +11,28 @@ composition-driven.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, cast
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLLMLayerConfig,
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.layers.shell import DIFY_SHELL_LAYER_TYPE_ID, DifyShellLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
@ -40,7 +46,88 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
AGENT_APP_USER_PROMPT_LAYER_ID = "agent_app_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
DIFY_SHELL_LAYER_ID = "shell"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
# the cleanup request) because we deliberately do not persist plaintext
|
||||
# credentials between runs.
|
||||
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
)
|
||||
|
||||
|
||||
class CleanupLayerSpec(BaseModel):
|
||||
"""One layer node replayed by an Agent backend cleanup-only run.
|
||||
|
||||
Cleanup composition cannot include credential-bearing plugin layers, so we
|
||||
persist only the non-plugin layer specs together with the original config.
|
||||
Storing the config (rather than just ``name``/``type``) means cleanup does
|
||||
not depend on the original build-time inputs being re-derivable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
deps: dict[str, str] = Field(default_factory=dict)
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
config: JsonValue = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
|
||||
"""Project the in-flight composition into the persistable cleanup spec list.
|
||||
|
||||
Plugin layers are intentionally dropped (their configs hold credentials and
|
||||
the lifecycle contract says "do not include an LLM layer" during cleanup).
|
||||
The filtered names must later drive snapshot filtering so the agenton
|
||||
compositor's name-order check still passes for the cleanup run.
|
||||
"""
|
||||
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
|
||||
specs: list[CleanupLayerSpec] = []
|
||||
for layer in composition.layers:
|
||||
if layer.type in excluded:
|
||||
continue
|
||||
config_value: JsonValue = None
|
||||
if isinstance(layer.config, BaseModel):
|
||||
config_value = layer.config.model_dump(mode="json", warnings=False)
|
||||
else:
|
||||
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
|
||||
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
|
||||
# pipeline our builder only emits BaseModel-derived configs or
|
||||
# ``None``, so the wider input alias narrows safely here.
|
||||
config_value = cast(JsonValue, layer.config)
|
||||
specs.append(
|
||||
CleanupLayerSpec(
|
||||
name=layer.name,
|
||||
type=layer.type,
|
||||
deps=dict(layer.deps),
|
||||
metadata=dict(layer.metadata),
|
||||
config=config_value,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def _filter_snapshot_to_specs(
|
||||
snapshot: CompositorSessionSnapshot,
|
||||
specs: list[CleanupLayerSpec],
|
||||
) -> CompositorSessionSnapshot:
|
||||
"""Keep only snapshot layers whose names appear in the cleanup spec list.
|
||||
|
||||
The agenton compositor rejects a snapshot whose layer-name sequence does
|
||||
not match the active composition exactly. Cleanup-replay drops plugin
|
||||
layers, so we must drop the matching snapshot entries here.
|
||||
"""
|
||||
kept_names = {spec.name for spec in specs}
|
||||
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
|
||||
if len(filtered_layers) == len(snapshot.layers):
|
||||
return snapshot
|
||||
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
@ -81,8 +168,14 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
purpose: RunPurpose = "workflow_node"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
suspend_on_exit: bool = False
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
@ -95,9 +188,198 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendAgentAppRunInput(BaseModel):
|
||||
"""Inputs to build one Agent App conversation-turn run request.
|
||||
|
||||
Unlike the workflow-node input there is no workflow-node-job prompt and no
|
||||
previous-node context: the user prompt is the chat message, and multi-turn
|
||||
continuity comes from ``session_snapshot`` + the history layer keyed by the
|
||||
conversation.
|
||||
"""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "agent_app"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
# Inject the sandboxed shell layer (dify.shell). Requires the agent backend
|
||||
# to be wired with a shellctl entrypoint; see configs AGENT_SHELL_ENABLED.
|
||||
include_shell: bool = False
|
||||
shell_config: DifyShellLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
|
||||
@field_validator("user_prompt")
|
||||
@classmethod
|
||||
def _reject_blank_prompt(cls, value: str) -> str:
|
||||
if not value.strip():
|
||||
raise ValueError("prompt must not be blank")
|
||||
return value
|
||||
|
||||
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_for_agent_app(self, run_input: AgentBackendAgentAppRunInput) -> CreateRunRequest:
|
||||
"""Build an Agent App conversation-turn run request.
|
||||
|
||||
Layer graph: optional Agent Soul system prompt → user prompt →
|
||||
execution context → optional history (multi-turn) → LLM → optional
|
||||
plugin tools → optional structured output. Mirrors the workflow-node
|
||||
layer ordering minus the workflow-job / previous-node prompt.
|
||||
"""
|
||||
layers: list[RunLayerSpec] = []
|
||||
if run_input.agent_soul_prompt:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_soul"},
|
||||
config=PromptLayerConfig(prefix=run_input.agent_soul_prompt),
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=AGENT_APP_USER_PROMPT_LAYER_ID,
|
||||
type=PLAIN_PROMPT_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_app_user_prompt"},
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
type=DIFY_OUTPUT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
session_snapshot=run_input.session_snapshot,
|
||||
on_exit=LayerExitSignals(
|
||||
default=ExitIntent.SUSPEND if run_input.suspend_on_exit else ExitIntent.DELETE,
|
||||
),
|
||||
)
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
session_snapshot: CompositorSessionSnapshot,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
idempotency_key: str | None = None,
|
||||
metadata: dict[str, JsonValue] | None = None,
|
||||
) -> CreateRunRequest:
|
||||
"""Build a lifecycle-only cleanup request that replays the prior layers.
|
||||
|
||||
The agenton compositor enforces that the session snapshot's layer names
|
||||
match the active composition in order, so cleanup must replay the same
|
||||
non-plugin layer graph that produced the snapshot. Plugin layers
|
||||
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
|
||||
composition and the snapshot before submission because their configs
|
||||
require credentials that are not persisted between runs.
|
||||
"""
|
||||
if not composition_layer_specs:
|
||||
raise ValueError(
|
||||
"build_cleanup_request requires composition_layer_specs; an empty "
|
||||
"composition would fail the agent backend's snapshot validation."
|
||||
)
|
||||
request_metadata = dict(metadata or {})
|
||||
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
|
||||
layers = [
|
||||
RunLayerSpec(
|
||||
name=spec.name,
|
||||
type=spec.type,
|
||||
deps=dict(spec.deps),
|
||||
metadata=dict(spec.metadata),
|
||||
config=spec.config,
|
||||
)
|
||||
for spec in composition_layer_specs
|
||||
]
|
||||
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose="workflow_node",
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=request_metadata,
|
||||
session_snapshot=filtered_snapshot,
|
||||
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
|
||||
)
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
@ -131,6 +413,20 @@ class AgentBackendRunRequestBuilder:
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
@ -147,6 +443,29 @@ class AgentBackendRunRequestBuilder:
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.include_shell:
|
||||
# Sandboxed bash workspace (dify.shell). The layer declares NoLayerDeps,
|
||||
# so the spec carries no deps; shellctl connection is server-injected.
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_SHELL_LAYER_ID,
|
||||
type=DIFY_SHELL_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.shell_config or DifyShellLayerConfig(),
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
|
||||
135
api/clients/agent_backend/workspace_files_client.py
Normal file
135
api/clients/agent_backend/workspace_files_client.py
Normal file
@ -0,0 +1,135 @@
|
||||
"""API-side client for the agent backend's read-only workspace file endpoints.
|
||||
|
||||
The agent backend exposes ``/workspaces/{session_id}/files{,/preview,/download}``
|
||||
to inspect a shell-layer sandbox workspace. This thin synchronous client proxies
|
||||
those reads for the console FS inspector and normalizes transport/HTTP failures
|
||||
into the API backend's ``AgentBackendError`` boundary, preserving the backend's
|
||||
status code and ``{code, message}`` detail so the controller can relay them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
|
||||
_DEFAULT_TIMEOUT_SECONDS = 30.0
|
||||
|
||||
|
||||
class WorkspaceFileEntry(BaseModel):
|
||||
"""One entry in a workspace directory listing."""
|
||||
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResult(BaseModel):
|
||||
"""Directory listing of a workspace path."""
|
||||
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntry]
|
||||
truncated: bool
|
||||
|
||||
|
||||
class WorkspacePreviewResult(BaseModel):
|
||||
"""Inline preview of a workspace file."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkspaceDownloadResult:
|
||||
"""Decoded bytes of a workspace file for download."""
|
||||
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
content: bytes
|
||||
|
||||
|
||||
class WorkspaceFilesBackendClient:
|
||||
"""Synchronous proxy to the agent backend workspace file endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
*,
|
||||
timeout: float = _DEFAULT_TIMEOUT_SECONDS,
|
||||
transport: httpx.BaseTransport | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._transport = transport
|
||||
|
||||
def list_files(self, session_id: str, path: str) -> WorkspaceListResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files", params={"path": path})
|
||||
return WorkspaceListResult.model_validate(data)
|
||||
|
||||
def preview(self, session_id: str, path: str) -> WorkspacePreviewResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/preview", params={"path": path})
|
||||
return WorkspacePreviewResult.model_validate(data)
|
||||
|
||||
def download(self, session_id: str, path: str) -> WorkspaceDownloadResult:
|
||||
data = self._get(f"/workspaces/{session_id}/files/download", params={"path": path})
|
||||
encoded = data.get("content_base64")
|
||||
if not isinstance(encoded, str):
|
||||
raise AgentBackendHTTPError("agent backend download response missing content", status_code=502, detail=data)
|
||||
try:
|
||||
content = base64.b64decode(encoded, validate=True)
|
||||
except (binascii.Error, ValueError) as exc:
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend returned undecodable download content", status_code=502, detail=str(exc)
|
||||
) from exc
|
||||
size = data.get("size")
|
||||
return WorkspaceDownloadResult(
|
||||
path=str(data.get("path", path)),
|
||||
size=int(size) if isinstance(size, (int, float)) else len(content),
|
||||
truncated=bool(data.get("truncated")),
|
||||
content=content,
|
||||
)
|
||||
|
||||
def _get(self, route: str, *, params: dict[str, str]) -> dict[str, object]:
|
||||
url = f"{self._base_url}{route}"
|
||||
try:
|
||||
with httpx.Client(timeout=self._timeout, transport=self._transport, trust_env=False) as client:
|
||||
response = client.get(url, params=params)
|
||||
except httpx.HTTPError as exc:
|
||||
raise AgentBackendTransportError(f"failed to reach agent backend workspace endpoint: {exc}") from exc
|
||||
if response.status_code >= 400:
|
||||
detail: object
|
||||
try:
|
||||
detail = response.json().get("detail", response.text)
|
||||
except ValueError:
|
||||
detail = response.text
|
||||
raise AgentBackendHTTPError(
|
||||
f"agent backend workspace request failed ({response.status_code})",
|
||||
status_code=response.status_code,
|
||||
detail=detail,
|
||||
)
|
||||
body = response.json()
|
||||
if not isinstance(body, dict):
|
||||
raise AgentBackendHTTPError(
|
||||
"agent backend workspace response was not an object", status_code=502, detail=body
|
||||
)
|
||||
return body
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WorkspaceDownloadResult",
|
||||
"WorkspaceFileEntry",
|
||||
"WorkspaceFilesBackendClient",
|
||||
"WorkspaceListResult",
|
||||
"WorkspacePreviewResult",
|
||||
]
|
||||
@ -3,6 +3,13 @@ CLI command modules extracted from `commands.py`.
|
||||
"""
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .data_migration import (
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
)
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -25,7 +32,12 @@ from .retention import (
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .system import (
|
||||
convert_to_agent_apps,
|
||||
fix_app_site_missing,
|
||||
reset_encrypt_key_pair,
|
||||
upgrade_db,
|
||||
)
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
@ -44,18 +56,24 @@ __all__ = [
|
||||
"clear_orphaned_file_records",
|
||||
"convert_to_agent_apps",
|
||||
"create_tenant",
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"export_migration_data",
|
||||
"export_migration_data_template",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"import_migration_data",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"migration_data_wizard",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
|
||||
179
api/commands/data_migrate.py
Normal file
179
api/commands/data_migrate.py
Normal file
@ -0,0 +1,179 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from services.legacy_model_type_migration import (
|
||||
VALID_TABLE_NAMES,
|
||||
LegacyModelTypeMigrationService,
|
||||
load_tenant_ids_from_file,
|
||||
)
|
||||
|
||||
_SUPPORTED_MODEL_TYPE_CHOICES = (
|
||||
ModelType.LLM.value,
|
||||
ModelType.TEXT_EMBEDDING.value,
|
||||
ModelType.RERANK.value,
|
||||
)
|
||||
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
|
||||
|
||||
|
||||
def _normalize_multi_value_option(
|
||||
values: tuple[str, ...],
|
||||
*,
|
||||
valid_values: tuple[str, ...],
|
||||
option_name: str,
|
||||
) -> tuple[str, ...]:
|
||||
normalized_values: list[str] = []
|
||||
seen_values: set[str] = set()
|
||||
|
||||
for value in values:
|
||||
for item in value.split(","):
|
||||
normalized_item = item.strip()
|
||||
if not normalized_item:
|
||||
continue
|
||||
if normalized_item not in valid_values:
|
||||
raise click.BadParameter(
|
||||
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
|
||||
param_hint=option_name,
|
||||
)
|
||||
if normalized_item in seen_values:
|
||||
continue
|
||||
seen_values.add(normalized_item)
|
||||
normalized_values.append(normalized_item)
|
||||
|
||||
return tuple(normalized_values)
|
||||
|
||||
|
||||
@click.group(
|
||||
"data-migrate",
|
||||
help="Online data migration commands.",
|
||||
)
|
||||
def data_migrate() -> None:
|
||||
"""Namespace for production data migration commands."""
|
||||
|
||||
|
||||
@click.command(
|
||||
"legacy-model-types",
|
||||
help=(
|
||||
"Migrate legacy provider model_type values to canonical values. "
|
||||
"Default is dry-run and emits JSON lines only. "
|
||||
"If --tables includes provider_model_credentials, the command may also update "
|
||||
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--apply",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Apply the migration. Default is dry-run.",
|
||||
)
|
||||
@click.option(
|
||||
"--tables",
|
||||
"tables",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Limit migration to specific tables. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: load_balancing_model_configs, provider_model_credentials, "
|
||||
"provider_model_settings, provider_models, tenant_default_models.\n\n"
|
||||
"When provider_model_credentials is selected, provider_models and "
|
||||
"load_balancing_model_configs may also be updated for credential reference rewrites.\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant tables are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-types",
|
||||
"model_types",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Canonical model types to migrate. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: llm,text-embedding,rerank\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant legacy model types are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-id-file",
|
||||
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
|
||||
help="Optional file containing tenant ids, one per line.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
|
||||
help=(
|
||||
"Optional file path for JSON lines event logs. Defaults to stdout.\n"
|
||||
"It's highly recommended to save the event logs to a file and preserve it for a period of time."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--concurrency",
|
||||
type=click.IntRange(min=1),
|
||||
default=_DEFAULT_CONCURRENCY,
|
||||
show_default=True,
|
||||
help="Number of tenant-level worker threads to run in parallel.",
|
||||
)
|
||||
def legacy_model_types(
|
||||
apply: bool,
|
||||
tables: tuple[str, ...],
|
||||
model_types: tuple[str, ...],
|
||||
tenant_id_file: str | None,
|
||||
output: Path | None,
|
||||
concurrency: int = _DEFAULT_CONCURRENCY,
|
||||
) -> None:
|
||||
"""
|
||||
Migrate legacy provider-related model_type values and emit JSON lines events.
|
||||
"""
|
||||
|
||||
normalized_tables = _normalize_multi_value_option(
|
||||
tables,
|
||||
valid_values=VALID_TABLE_NAMES,
|
||||
option_name="--tables",
|
||||
)
|
||||
normalized_model_types = _normalize_multi_value_option(
|
||||
model_types,
|
||||
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
|
||||
option_name="--model-types",
|
||||
)
|
||||
selected_model_types = (
|
||||
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
|
||||
if normalized_model_types
|
||||
else (
|
||||
ModelType.LLM,
|
||||
ModelType.TEXT_EMBEDDING,
|
||||
ModelType.RERANK,
|
||||
)
|
||||
)
|
||||
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
|
||||
|
||||
output_context: AbstractContextManager[io.TextIOBase]
|
||||
if output is None:
|
||||
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
|
||||
else:
|
||||
try:
|
||||
output_context = output.open("w", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
|
||||
|
||||
with output_context as output_stream:
|
||||
LegacyModelTypeMigrationService(
|
||||
engine=db.engine,
|
||||
apply=apply,
|
||||
concurrency=concurrency,
|
||||
output=cast(io.TextIOBase, output_stream),
|
||||
tables=normalized_tables or None,
|
||||
model_types=selected_model_types,
|
||||
tenant_ids=tenant_ids,
|
||||
).migrate()
|
||||
|
||||
|
||||
data_migrate.add_command(legacy_model_types)
|
||||
754
api/commands/data_migration.py
Normal file
754
api/commands/data_migration.py
Normal file
@ -0,0 +1,754 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.model import App
|
||||
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
|
||||
from services.data_migration.entities import (
|
||||
DependencyKind,
|
||||
ImportOptions,
|
||||
MigrationDataError,
|
||||
ReportContext,
|
||||
ResourceReportItem,
|
||||
)
|
||||
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
|
||||
from services.data_migration.import_service import ImportRequest, MigrationImportService
|
||||
from services.data_migration.package_service import MigrationPackageService
|
||||
from services.data_migration.report_service import MigrationReportService
|
||||
|
||||
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
|
||||
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
|
||||
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
|
||||
WizardToolMap = dict[str, dict[str, str | None]]
|
||||
WizardToolSelection = dict[str, list[str]]
|
||||
|
||||
|
||||
def _scripted_export_template() -> dict[str, Any]:
|
||||
return {
|
||||
"source_tenant": {
|
||||
"mode": "single",
|
||||
"id": "",
|
||||
"name": "admin's Workspace",
|
||||
},
|
||||
"apps": {
|
||||
"modes": ["workflow", "advanced-chat"],
|
||||
"ids": [],
|
||||
"all": True,
|
||||
},
|
||||
"include_referenced_tools": True,
|
||||
"additional_tools": {
|
||||
"api_tools": [],
|
||||
"workflow_tools": [],
|
||||
"mcp_tools": [],
|
||||
},
|
||||
"include_secrets": False,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": False,
|
||||
"id_strategy": "preserve-id",
|
||||
"conflict_strategy": "fail",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to write the export config JSON template. Prints to stdout when omitted.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
|
||||
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
|
||||
if output_file is None:
|
||||
click.echo(template_json, nl=False)
|
||||
return
|
||||
path = Path(output_file)
|
||||
if path.exists() and not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
path.write_text(template_json)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to export config JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file), ("--output", output_file))
|
||||
assert input_file is not None
|
||||
assert output_file is not None
|
||||
raw_config = _load_json_object(input_file, "Export config")
|
||||
selection = ExportConfigParser().parse(raw_config)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
@click.command("import-app-migration", help="Import a versioned migration data package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
|
||||
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
|
||||
@click.option(
|
||||
"--id-strategy",
|
||||
default=None,
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
help="Override package ID strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--conflict-strategy",
|
||||
default=None,
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
help="Override package conflict strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
|
||||
default=None,
|
||||
help="Override package app API token creation behavior.",
|
||||
)
|
||||
def import_migration_data(
|
||||
input_file: str | None,
|
||||
target_tenant: str | None,
|
||||
operator_email: str | None,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file))
|
||||
assert input_file is not None
|
||||
package = MigrationPackageService().load_package(input_file)
|
||||
result = MigrationImportService().import_package(
|
||||
ImportRequest(
|
||||
package=package,
|
||||
cli_target_tenant=target_tenant,
|
||||
operator_email=operator_email,
|
||||
options_override=_build_options_override(
|
||||
package.metadata.import_options,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
create_app_api_token_on_import=create_app_api_token_on_import,
|
||||
),
|
||||
)
|
||||
)
|
||||
_render_report(result.report_items, context=result.report_context)
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
|
||||
normalized = raw.strip().lower()
|
||||
if normalized == "all":
|
||||
return values
|
||||
|
||||
selected: list[str] = []
|
||||
for part in raw.split(","):
|
||||
stripped = part.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
index = int(stripped)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
|
||||
if index < 1 or index > len(values):
|
||||
raise click.ClickException(f"Selection index out of range: {index}")
|
||||
selected.append(values[index - 1])
|
||||
return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _print_wizard_step(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"==== {title} ====")
|
||||
|
||||
|
||||
def _print_wizard_substep(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"-- {title} --")
|
||||
|
||||
|
||||
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
|
||||
def migration_data_wizard() -> None:
|
||||
try:
|
||||
tenant = _prompt_source_tenant()
|
||||
apps = _eligible_apps_for_tenant(tenant.id)
|
||||
app_ids = _prompt_app_ids(apps)
|
||||
_print_wizard_step("Referenced Tools")
|
||||
include_referenced_tools = click.confirm(
|
||||
"Automatically export tools referenced by selected apps? [y/n, default: y]",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
|
||||
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
|
||||
_print_auto_tools(auto_tools)
|
||||
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
|
||||
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
|
||||
_print_wizard_step("Output")
|
||||
output_file, overwrite = _prompt_output_file()
|
||||
|
||||
selection = ExportConfigParser().parse(
|
||||
{
|
||||
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
|
||||
"apps": {"ids": app_ids, "all": False},
|
||||
"include_referenced_tools": include_referenced_tools,
|
||||
"additional_tools": additional_tools,
|
||||
"include_secrets": include_secrets,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": create_tokens,
|
||||
"id_strategy": id_strategy,
|
||||
"conflict_strategy": conflict_strategy,
|
||||
},
|
||||
}
|
||||
)
|
||||
_confirm_wizard_summary(
|
||||
tenant_name=tenant.name,
|
||||
app_names=[app.name for app in apps if app.id in set(app_ids)],
|
||||
auto_tools=auto_tools,
|
||||
additional_tools=additional_tools,
|
||||
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
|
||||
include_referenced_tools=include_referenced_tools,
|
||||
include_secrets=include_secrets,
|
||||
create_tokens=create_tokens,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
output_file=output_file,
|
||||
)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_print_wizard_step("Report")
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def _load_json_object(path: str, label: str) -> dict[str, Any]:
|
||||
try:
|
||||
with Path(path).open(encoding="utf-8") as file:
|
||||
raw = json.load(file)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise MigrationDataError(f"{label} JSON must be an object.")
|
||||
return raw
|
||||
|
||||
|
||||
def _require_options(*options: tuple[str, object | None]) -> None:
|
||||
missing_options = [name for name, value in options if value is None]
|
||||
if missing_options:
|
||||
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
|
||||
|
||||
|
||||
def _build_options_override(
|
||||
package_options: ImportOptions,
|
||||
*,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> ImportOptions | None:
|
||||
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
|
||||
return None
|
||||
return ImportOptions.from_mapping(
|
||||
{
|
||||
"id_strategy": id_strategy or package_options.id_strategy,
|
||||
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
|
||||
"create_app_api_token_on_import": (
|
||||
create_app_api_token_on_import
|
||||
if create_app_api_token_on_import is not None
|
||||
else package_options.create_app_api_token_on_import
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _prompt_source_tenant() -> Tenant:
|
||||
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
|
||||
if not tenants:
|
||||
raise MigrationDataError("No tenants found.")
|
||||
|
||||
_print_wizard_step("Source Tenant")
|
||||
click.echo("Source tenants:")
|
||||
for index, tenant in enumerate(tenants, 1):
|
||||
click.echo(f"{index}. {tenant.name} ({tenant.id})")
|
||||
|
||||
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
|
||||
if tenant_index < 1 or tenant_index > len(tenants):
|
||||
raise click.ClickException(f"Selection index out of range: {tenant_index}")
|
||||
return tenants[tenant_index - 1]
|
||||
|
||||
|
||||
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
|
||||
return list(
|
||||
db.session.scalars(
|
||||
sa.select(App)
|
||||
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
|
||||
.order_by(App.name.asc())
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def _prompt_app_ids(apps: list[App]) -> list[str]:
|
||||
if not apps:
|
||||
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
|
||||
|
||||
_print_wizard_step("App Selection")
|
||||
click.echo("Currently supported app types: workflow and chatflow.")
|
||||
click.echo("Workflow/chatflow apps:")
|
||||
for index, app in enumerate(apps, 1):
|
||||
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
|
||||
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
|
||||
app_ids = parse_index_selection(
|
||||
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
|
||||
[app.id for app in apps],
|
||||
)
|
||||
selected_apps = [app for app in apps if app.id in set(app_ids)]
|
||||
click.echo("Selected apps:")
|
||||
for app in selected_apps:
|
||||
click.echo(f"- {app.name} ({app.id})")
|
||||
return app_ids
|
||||
|
||||
|
||||
def _prompt_import_options() -> tuple[bool, bool, str, str]:
|
||||
_print_wizard_step("Import Options")
|
||||
_print_wizard_substep("Secrets")
|
||||
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
|
||||
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
|
||||
click.echo("If you choose no, credentials are omitted or masked,")
|
||||
click.echo("and MCP providers are exported as dependency metadata only.")
|
||||
click.echo("Treat the output JSON as sensitive if you choose yes.")
|
||||
include_secrets = click.confirm(
|
||||
"Include secrets in output JSON? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("App API Tokens")
|
||||
click.echo("When enabled, import will create an app API token if the imported app has none,")
|
||||
click.echo("or reuse an existing app API token if one already exists.")
|
||||
create_tokens = click.confirm(
|
||||
"Create or reuse app API tokens during import? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("ID Strategy")
|
||||
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
|
||||
click.echo("or use target-generated IDs.")
|
||||
click.echo("preserve-id: keep source IDs where the target service supports it.")
|
||||
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
|
||||
id_strategy = click.prompt(
|
||||
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
default="preserve-id",
|
||||
show_default=True,
|
||||
)
|
||||
_print_wizard_substep("Conflict Strategy")
|
||||
click.echo("Conflict strategy controls what import does when a target resource already exists.")
|
||||
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
|
||||
click.echo("skip: keep the existing target resource and skip importing that resource.")
|
||||
click.echo("update: update the existing target resource in place.")
|
||||
conflict_strategy = click.prompt(
|
||||
"Import conflict strategy. Enter one of: fail, skip, update",
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
default="update",
|
||||
show_default=True,
|
||||
)
|
||||
return include_secrets, create_tokens, id_strategy, conflict_strategy
|
||||
|
||||
|
||||
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
|
||||
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
|
||||
if not include_referenced_tools:
|
||||
return auto_tools
|
||||
discovery_service = DependencyDiscoveryService()
|
||||
for app in apps:
|
||||
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
|
||||
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
|
||||
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
|
||||
for dependency in discovery_service.discover_from_dsl(dsl):
|
||||
if dependency.kind == DependencyKind.API_TOOL:
|
||||
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
|
||||
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
|
||||
dependency.provider_id
|
||||
)
|
||||
elif dependency.kind == DependencyKind.MCP_TOOL:
|
||||
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
return auto_tools
|
||||
|
||||
|
||||
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
|
||||
return {
|
||||
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
|
||||
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
|
||||
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [ApiToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(ApiToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [WorkflowToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(WorkflowToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [MCPToolProvider.name == name]
|
||||
if identifier:
|
||||
predicates.append(MCPToolProvider.server_identifier == identifier)
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(MCPToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_uuid_string(value: str | None) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
|
||||
_print_wizard_step("Automatically Discovered Tools")
|
||||
click.echo("Automatically discovered tools:")
|
||||
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
|
||||
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
|
||||
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
|
||||
|
||||
|
||||
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
|
||||
click.echo(label)
|
||||
if not values:
|
||||
click.echo("- none")
|
||||
return
|
||||
for name, identifier in sorted(values.items()):
|
||||
click.echo(f"- {_format_tool_name_id(name, identifier)}")
|
||||
|
||||
|
||||
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
|
||||
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
|
||||
_print_wizard_step("Additional Tools")
|
||||
if not click.confirm(
|
||||
"Export additional tools manually? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
):
|
||||
_print_final_tool_selection(auto_tools, selections, {})
|
||||
return selections
|
||||
manual_labels: dict[str, str] = {}
|
||||
api_tool_options = [
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["api_tools"] = _prompt_tool_category(
|
||||
"Custom API tools",
|
||||
api_tool_options,
|
||||
auto_tools=auto_tools["api_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
|
||||
workflow_tool_options = [
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["workflow_tools"] = _prompt_tool_category(
|
||||
"Workflow tools",
|
||||
workflow_tool_options,
|
||||
auto_tools=auto_tools["workflow_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
|
||||
mcp_tool_options = [
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["mcp_tools"] = _prompt_tool_category(
|
||||
"MCP tools",
|
||||
mcp_tool_options,
|
||||
auto_tools=auto_tools["mcp_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
|
||||
_print_final_tool_selection(auto_tools, selections, manual_labels)
|
||||
return selections
|
||||
|
||||
|
||||
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
|
||||
labels: dict[str, str] = {}
|
||||
if selected_tools["api_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id)
|
||||
.order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["api_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["workflow_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["workflow_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["mcp_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["mcp_tools"],
|
||||
)
|
||||
)
|
||||
return labels
|
||||
|
||||
|
||||
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
|
||||
selected = set(selected_values)
|
||||
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
|
||||
|
||||
|
||||
def _prompt_tool_category(
|
||||
label: str,
|
||||
options: list[tuple[str, str, str]],
|
||||
*,
|
||||
auto_tools: dict[str, str | None],
|
||||
) -> list[str]:
|
||||
if not options:
|
||||
click.echo(f"{label}: none")
|
||||
return []
|
||||
_print_wizard_step(label)
|
||||
for index, (value, name, detail) in enumerate(options, 1):
|
||||
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
|
||||
click.echo(f"{index}. {marker} {name} ({detail})")
|
||||
raw = click.prompt(
|
||||
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
|
||||
default="",
|
||||
show_default=cast(Any, "empty"),
|
||||
)
|
||||
if not raw.strip():
|
||||
return []
|
||||
return parse_index_selection(raw, [value for value, _, _ in options])
|
||||
|
||||
|
||||
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
|
||||
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
|
||||
|
||||
|
||||
def _print_final_tool_selection(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
_print_wizard_step("Final Tool Selection")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
|
||||
|
||||
def _print_tool_selection_body(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo("Final tools to export:")
|
||||
_print_final_tool_category(
|
||||
"Custom API tools",
|
||||
auto_tools["api_tools"],
|
||||
additional_tools["api_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category(
|
||||
"Workflow tools",
|
||||
auto_tools["workflow_tools"],
|
||||
additional_tools["workflow_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
|
||||
|
||||
|
||||
def _print_final_tool_category(
|
||||
label: str,
|
||||
auto_tools: dict[str, str | None],
|
||||
manual_values: list[str],
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo(label)
|
||||
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
|
||||
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
|
||||
lines.extend(
|
||||
f"- [manual] {manual_labels.get(value, value)}"
|
||||
for value in manual_values
|
||||
if value not in auto_tools and value not in auto_identifiers
|
||||
)
|
||||
if not lines:
|
||||
click.echo("- none")
|
||||
return
|
||||
for line in lines:
|
||||
click.echo(line)
|
||||
|
||||
|
||||
def _format_tool_name_id(name: str, identifier: str | None) -> str:
|
||||
if identifier and identifier != name:
|
||||
return f"{name}: {identifier}"
|
||||
return name
|
||||
|
||||
|
||||
def _confirm_wizard_summary(
|
||||
*,
|
||||
tenant_name: str,
|
||||
app_names: list[str],
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
include_referenced_tools: bool,
|
||||
include_secrets: bool,
|
||||
create_tokens: bool,
|
||||
id_strategy: str,
|
||||
conflict_strategy: str,
|
||||
output_file: str,
|
||||
) -> None:
|
||||
_print_wizard_step("Summary")
|
||||
click.echo("Migration export summary:")
|
||||
click.echo(f"source tenant: {tenant_name}")
|
||||
click.echo(f"selected apps: {len(app_names)}")
|
||||
for app_name in app_names:
|
||||
click.echo(f"- {app_name}")
|
||||
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
click.echo(f"include secrets: {str(include_secrets).lower()}")
|
||||
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
|
||||
click.echo(f"id strategy: {id_strategy}")
|
||||
click.echo(f"conflict strategy: {conflict_strategy}")
|
||||
click.echo(f"output path: {output_file}")
|
||||
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _prompt_output_file() -> tuple[str, bool]:
|
||||
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
output_file = click.prompt("Output path", default=default_output, show_default=True)
|
||||
if output_file.lower() in {"y", "yes", "n", "no"}:
|
||||
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
|
||||
overwrite = False
|
||||
if Path(output_file).exists():
|
||||
overwrite = click.confirm(
|
||||
"Output file exists. Overwrite? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
return output_file, overwrite
|
||||
|
||||
|
||||
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
|
||||
if context is None:
|
||||
return ReportContext(output_path=output_path)
|
||||
return ReportContext(
|
||||
output_path=output_path,
|
||||
source_scope=context.source_scope,
|
||||
selected_app_count=context.selected_app_count,
|
||||
include_secrets=context.include_secrets,
|
||||
target_tenant=context.target_tenant,
|
||||
operator_email=context.operator_email,
|
||||
app_api_tokens_created=context.app_api_tokens_created,
|
||||
app_api_tokens_reused=context.app_api_tokens_reused,
|
||||
id_mapping_count=context.id_mapping_count,
|
||||
id_mappings=context.id_mappings,
|
||||
)
|
||||
|
||||
|
||||
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
|
||||
for line in MigrationReportService().render(report_items, context=context):
|
||||
click.echo(line)
|
||||
@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation datas to target vector database .
|
||||
Migrate annotation data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database datas to target vector database .
|
||||
Migrate vector database data to target vector database.
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
|
||||
@ -29,6 +29,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
current_state = self.current_state
|
||||
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
|
||||
|
||||
@ -21,3 +21,13 @@ class AgentBackendConfig(BaseSettings):
|
||||
description="Scenario used by the fake Agent backend client.",
|
||||
default="success",
|
||||
)
|
||||
|
||||
AGENT_SHELL_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Inject the dify.shell layer (sandboxed bash workspace) into Agent runs. "
|
||||
"Requires the agent backend to be wired with a shellctl entrypoint; keep it "
|
||||
"off until shellctl is deployed, otherwise every agent run that includes the "
|
||||
"shell layer will fail."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
@ -41,3 +41,21 @@ class MilvusConfig(BaseSettings):
|
||||
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
|
||||
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_PEM_PATH: str | None = Field(
|
||||
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
|
||||
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_NAME: str | None = Field(
|
||||
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
|
||||
"Required when MILVUS_SERVER_PEM_PATH is set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ -81,4 +81,15 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new Agent App type). The runtime model / prompt / tools
|
||||
# come from the bound Agent Soul snapshot, so no model_config is seeded in the
|
||||
# template; create_app still creates a model-less app_model_config row to hold
|
||||
# app-level presentation features (opener, follow-up, citations, ...).
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,10 +1,40 @@
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
action: str
|
||||
|
||||
|
||||
|
||||
@ -6,10 +6,11 @@ These helpers keep that translation centralized so models registered through
|
||||
`register_schema_models` emit resolvable Swagger 2.0 references.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
from typing import Any, Literal, NotRequired, Protocol, TypedDict
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
@ -35,6 +36,14 @@ QueryParamDoc = TypedDict(
|
||||
},
|
||||
)
|
||||
|
||||
JsonResponseWithStatus = tuple[dict[str, Any], int]
|
||||
|
||||
|
||||
class QueryArgs(Protocol):
|
||||
def to_dict(self, flat: bool = True) -> dict[str, str]: ...
|
||||
|
||||
def getlist(self, key: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _register_json_schema(namespace: Namespace, name: str, schema: dict) -> None:
|
||||
"""Register a JSON schema and promote any nested Pydantic `$defs`."""
|
||||
@ -167,6 +176,58 @@ def query_params_from_model(model: type[BaseModel]) -> dict[str, QueryParamDoc]:
|
||||
return params
|
||||
|
||||
|
||||
def query_params_from_request[ModelT: BaseModel](
|
||||
model: type[ModelT],
|
||||
*,
|
||||
list_fields: Iterable[str] = (),
|
||||
args: QueryArgs | None = None,
|
||||
use_defaults_for_malformed_ints: bool = False,
|
||||
) -> ModelT:
|
||||
"""Validate query args with Pydantic while preserving Flask query parsing behavior.
|
||||
|
||||
Repeated params need explicit ``getlist()`` handling because Werkzeug's
|
||||
``to_dict()`` keeps only one value. For malformed scalar integers, Flask's
|
||||
For endpoints migrated from ``request.args.get(..., type=int, default=...)``,
|
||||
set ``use_defaults_for_malformed_ints`` to preserve Flask's fallback to
|
||||
defaults for malformed optional integer params.
|
||||
"""
|
||||
|
||||
query_args = args or request.args
|
||||
params: dict[str, Any] = query_args.to_dict()
|
||||
for field_name in list_fields:
|
||||
params[field_name] = query_args.getlist(field_name)
|
||||
|
||||
if use_defaults_for_malformed_ints:
|
||||
_drop_malformed_defaulted_integer_params(model, params)
|
||||
return model.model_validate(params)
|
||||
|
||||
|
||||
def _drop_malformed_defaulted_integer_params(model: type[BaseModel], params: dict[str, Any]) -> None:
|
||||
properties = model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0).get("properties", {})
|
||||
if not isinstance(properties, Mapping):
|
||||
return
|
||||
|
||||
for name, value in list(params.items()):
|
||||
if not isinstance(value, str):
|
||||
continue
|
||||
|
||||
field = model.model_fields.get(name)
|
||||
if field is None or field.is_required():
|
||||
continue
|
||||
|
||||
property_schema = properties.get(name)
|
||||
if not isinstance(property_schema, Mapping):
|
||||
continue
|
||||
|
||||
if _nullable_property_schema(property_schema).get("type") != "integer":
|
||||
continue
|
||||
|
||||
try:
|
||||
int(value)
|
||||
except ValueError:
|
||||
params.pop(name)
|
||||
|
||||
|
||||
def _query_param_from_property(property_schema: Mapping[str, Any], *, required: bool) -> QueryParamDoc:
|
||||
param_schema = _nullable_property_schema(property_schema)
|
||||
param_doc: QueryParamDoc = {"in": "query", "required": required}
|
||||
@ -239,6 +300,7 @@ __all__ = [
|
||||
"DEFAULT_REF_TEMPLATE_SWAGGER_2_0",
|
||||
"get_or_create_model",
|
||||
"query_params_from_model",
|
||||
"query_params_from_request",
|
||||
"register_enum_models",
|
||||
"register_response_schema_model",
|
||||
"register_response_schema_models",
|
||||
|
||||
@ -51,6 +51,9 @@ from .agent import roster as agent_roster
|
||||
from .app import (
|
||||
advanced_prompt_template,
|
||||
agent,
|
||||
agent_app_access,
|
||||
agent_app_feature,
|
||||
agent_app_workspace,
|
||||
annotation,
|
||||
app,
|
||||
audio,
|
||||
@ -146,6 +149,9 @@ __all__ = [
|
||||
"activate",
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_app_access",
|
||||
"agent_app_feature",
|
||||
"agent_app_workspace",
|
||||
"agent_composer",
|
||||
"agent_providers",
|
||||
"agent_roster",
|
||||
|
||||
@ -1,153 +1,229 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from fields.agent_fields import (
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
|
||||
register_schema_models(console_ns, ComposerSavePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerValidateResponse,
|
||||
WorkflowAgentComposerResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer")
|
||||
class WorkflowAgentComposerApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer state", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/validate")
|
||||
class WorkflowAgentComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
def post(self, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/candidates")
|
||||
class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_workflow_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/impact")
|
||||
class WorkflowAgentComposerImpactApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Workflow agent composer impact", console_ns.models[AgentComposerImpactResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
if not current_snapshot_id:
|
||||
return {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
return AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse, {"current_snapshot_id": None, "workflow_node_count": 0, "bindings": []}
|
||||
)
|
||||
return dump_response(
|
||||
AgentComposerImpactResponse,
|
||||
AgentComposerService.calculate_impact(tenant_id=tenant_id, current_snapshot_id=current_snapshot_id),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/agent-composer/save-to-roster")
|
||||
class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Workflow agent composer saved to roster", console_ns.models[WorkflowAgentComposerResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str, app_model: App, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
WorkflowAgentComposerResponse,
|
||||
AgentComposerService.save_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
node_id=node_id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer")
|
||||
class AgentAppComposerApi(Resource):
|
||||
@console_ns.response(200, "Agent app composer state", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(200, "Agent app composer saved", console_ns.models[AgentAppComposerResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def put(self, tenant_id: str, account_id: str, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account.id,
|
||||
payload=payload,
|
||||
return dump_response(
|
||||
AgentAppComposerResponse,
|
||||
AgentComposerService.save_agent_app_composer(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
account_id=account_id,
|
||||
payload=payload,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/validate")
|
||||
class AgentAppComposerValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ComposerSavePayload.__name__])
|
||||
@console_ns.response(
|
||||
200, "Agent app composer validation result", console_ns.models[AgentComposerValidateResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
return dump_response(AgentComposerValidateResponse, {"result": "success", "errors": []})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-composer/candidates")
|
||||
class AgentAppComposerCandidatesApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Agent app composer candidates", console_ns.models[AgentComposerCandidatesResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
def get(self, app_model: App):
|
||||
return dump_response(
|
||||
AgentComposerCandidatesResponse,
|
||||
AgentComposerService.get_agent_app_candidates(app_id=app_model.id),
|
||||
)
|
||||
|
||||
@ -4,11 +4,25 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.agent_fields import (
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
from services.entities.agent_entities import RosterAgentCreatePayload, RosterAgentUpdatePayload, RosterListQuery
|
||||
|
||||
@ -29,6 +43,14 @@ register_schema_models(
|
||||
RosterAgentUpdatePayload,
|
||||
RosterListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
AgentConfigSnapshotListResponse,
|
||||
AgentInviteOptionsResponse,
|
||||
AgentRosterListResponse,
|
||||
AgentRosterResponse,
|
||||
)
|
||||
|
||||
|
||||
def _agent_roster_service() -> AgentRosterService:
|
||||
@ -37,96 +59,130 @@ def _agent_roster_service() -> AgentRosterService:
|
||||
|
||||
@console_ns.route("/agents")
|
||||
class AgentRosterListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(RosterListQuery))
|
||||
@console_ns.response(200, "Agent roster list", console_ns.models[AgentRosterListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = RosterListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
return dump_response(
|
||||
AgentRosterListResponse,
|
||||
_agent_roster_service().list_roster_agents(
|
||||
tenant_id=tenant_id, page=query.page, limit=query.limit, keyword=query.keyword
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Agent created", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, account_id: str):
|
||||
payload = RosterAgentCreatePayload.model_validate(console_ns.payload or {})
|
||||
service = _agent_roster_service()
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account.id, payload=payload)
|
||||
return service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id), 201
|
||||
agent = service.create_roster_agent(tenant_id=tenant_id, account_id=account_id, payload=payload)
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
service.get_roster_agent_detail(tenant_id=tenant_id, agent_id=agent.id),
|
||||
), 201
|
||||
|
||||
|
||||
@console_ns.route("/agents/invite-options")
|
||||
class AgentInviteOptionsApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(AgentInviteOptionsQuery))
|
||||
@console_ns.response(200, "Agent invite options", console_ns.models[AgentInviteOptionsResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
query = AgentInviteOptionsQuery.model_validate(request.args.to_dict(flat=True))
|
||||
return _agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
return dump_response(
|
||||
AgentInviteOptionsResponse,
|
||||
_agent_roster_service().list_invite_options(
|
||||
tenant_id=tenant_id,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
keyword=query.keyword,
|
||||
app_id=query.app_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>")
|
||||
class AgentRosterDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent detail", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id, payload=payload
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id)),
|
||||
)
|
||||
|
||||
@console_ns.expect(console_ns.models[RosterAgentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Agent updated", console_ns.models[AgentRosterResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return dump_response(
|
||||
AgentRosterResponse,
|
||||
_agent_roster_service().update_roster_agent(
|
||||
tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id, payload=payload
|
||||
),
|
||||
)
|
||||
|
||||
@console_ns.response(204, "Agent archived")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, account_id: str, agent_id: UUID):
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions")
|
||||
class AgentRosterVersionsApi(Resource):
|
||||
@console_ns.response(200, "Agent versions", console_ns.models[AgentConfigSnapshotListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotListResponse,
|
||||
{"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/agents/<uuid:agent_id>/versions/<uuid:version_id>")
|
||||
class AgentRosterVersionDetailApi(Resource):
|
||||
@console_ns.response(200, "Agent version detail", console_ns.models[AgentConfigSnapshotDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, agent_id: UUID, version_id: UUID):
|
||||
return dump_response(
|
||||
AgentConfigSnapshotDetailResponse,
|
||||
_agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=str(agent_id),
|
||||
version_id=str(version_id),
|
||||
),
|
||||
)
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class AgentLogApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
59
api/controllers/console/app/agent_app_access.py
Normal file
59
api/controllers/console/app/agent_app_access.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Agent App access & sharing endpoints (read-only workflow references).
|
||||
|
||||
An Agent App is backed by a roster Agent that workflow Agent nodes may also
|
||||
reference. This exposes the read-only "Workflow access" surface from the PRD:
|
||||
which workflow apps use this Agent, without leaking the workflows' internals.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent.roster_service import AgentRosterService
|
||||
|
||||
|
||||
class AgentReferencingWorkflowResponse(ResponseModel):
|
||||
app_id: str
|
||||
app_name: str
|
||||
app_mode: str
|
||||
workflow_id: str
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentReferencingWorkflowsResponse(ResponseModel):
|
||||
data: list[AgentReferencingWorkflowResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, AgentReferencingWorkflowsResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-referencing-workflows")
|
||||
class AgentAppReferencingWorkflowsResource(Resource):
|
||||
@console_ns.doc("list_agent_app_referencing_workflows")
|
||||
@console_ns.doc(description="List workflow apps that reference this Agent App's bound Agent (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Referencing workflows listed successfully",
|
||||
console_ns.models[AgentReferencingWorkflowsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, app_model: App):
|
||||
workflows = AgentRosterService(db.session).list_workflows_referencing_app_agent(
|
||||
tenant_id=tenant_id, app_id=app_model.id
|
||||
)
|
||||
return AgentReferencingWorkflowsResponse(
|
||||
data=[AgentReferencingWorkflowResponse.model_validate(workflow) for workflow in workflows]
|
||||
).model_dump(mode="json")
|
||||
93
api/controllers/console/app/agent_app_feature.py
Normal file
93
api/controllers/console/app/agent_app_feature.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Agent App presentation-feature configuration endpoint.
|
||||
|
||||
The new Agent App type keeps model / prompt / tools in its bound Agent Soul, so
|
||||
the legacy ``/model-config`` surface (which writes model, prompt and agent tool
|
||||
config) is the wrong place to configure its app-level presentation features.
|
||||
This endpoint exposes only the PRD "Misc Legacy" feature subset — conversation
|
||||
opener, follow-up suggestions, citations, content moderation and speech — and
|
||||
persists them onto the app's ``app_model_config`` without touching anything the
|
||||
Soul owns.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.agent_config_entities import (
|
||||
AgentFeatureToggleConfig,
|
||||
AgentSensitiveWordAvoidanceFeatureConfig,
|
||||
AgentSuggestedQuestionsAfterAnswerFeatureConfig,
|
||||
AgentTextToSpeechFeatureConfig,
|
||||
)
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_feature_service import AgentAppFeatureConfigService
|
||||
|
||||
|
||||
class AgentAppFeaturesPayload(BaseModel):
|
||||
"""Presentation features configurable on an Agent App.
|
||||
|
||||
All fields are optional; an omitted field is reset to its disabled/empty
|
||||
default (the config form sends the full desired feature state on save).
|
||||
"""
|
||||
|
||||
opening_statement: str | None = Field(default=None, description="Conversation opener shown before the first turn")
|
||||
suggested_questions: list[str] | None = Field(
|
||||
default=None, description="Preset questions shown alongside the opener"
|
||||
)
|
||||
suggested_questions_after_answer: AgentSuggestedQuestionsAfterAnswerFeatureConfig | None = Field(
|
||||
default=None, description="Follow-up suggestions config, e.g. {'enabled': true}"
|
||||
)
|
||||
speech_to_text: AgentFeatureToggleConfig | None = Field(default=None, description="Speech-to-text config")
|
||||
text_to_speech: AgentTextToSpeechFeatureConfig | None = Field(default=None, description="Text-to-speech config")
|
||||
retriever_resource: AgentFeatureToggleConfig | None = Field(
|
||||
default=None, description="Citations / attributions config, e.g. {'enabled': true}"
|
||||
)
|
||||
sensitive_word_avoidance: AgentSensitiveWordAvoidanceFeatureConfig | None = Field(
|
||||
default=None, description="Content moderation config"
|
||||
)
|
||||
|
||||
|
||||
register_schema_models(console_ns, AgentAppFeaturesPayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-features")
|
||||
class AgentAppFeatureConfigResource(Resource):
|
||||
@console_ns.doc("update_agent_app_features")
|
||||
@console_ns.doc(description="Update an Agent App's presentation features (opener, follow-up, citations, ...)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AgentAppFeaturesPayload.__name__])
|
||||
@console_ns.response(200, "Features updated successfully", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AgentAppFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
new_app_model_config = AgentAppFeatureConfigService.update_features(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
config=args.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"})
|
||||
319
api/controllers/console/app/agent_app_workspace.py
Normal file
319
api/controllers/console/app/agent_app_workspace.py
Normal file
@ -0,0 +1,319 @@
|
||||
"""Agent App sandbox file-system inspector (read-only).
|
||||
|
||||
Exposes the PRD "rc1-like sandbox file system, downloadable not editable" view
|
||||
for an Agent App conversation: list a directory, preview a file, or download a
|
||||
file from the conversation's shell-layer workspace. The API never touches
|
||||
shellctl directly — it resolves the conversation's sandbox ``session_id`` from
|
||||
the stored session snapshot and proxies to the agent backend's read-only
|
||||
workspace endpoints.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from clients.agent_backend.errors import AgentBackendHTTPError, AgentBackendTransportError
|
||||
from clients.agent_backend.workspace_files_client import WorkspaceDownloadResult
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App, AppMode
|
||||
from services.agent_app_workspace_service import (
|
||||
AgentAppWorkspaceService,
|
||||
AgentWorkspaceInspectorError,
|
||||
WorkflowAgentWorkspaceService,
|
||||
)
|
||||
|
||||
|
||||
class _WorkspaceFileDownloadField(fields.Raw):
|
||||
__schema_type__ = "string"
|
||||
__schema_format__ = "binary"
|
||||
|
||||
|
||||
class AgentWorkspaceListQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class AgentWorkspaceFileQuery(BaseModel):
|
||||
conversation_id: str = Field(min_length=1, description="Agent App conversation ID")
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceListQuery(BaseModel):
|
||||
path: str = Field(default=".", description="Directory path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAgentWorkspaceFileQuery(BaseModel):
|
||||
path: str = Field(min_length=1, description="File path relative to the sandbox workspace")
|
||||
node_execution_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional workflow node execution ID. When omitted, the latest active session for the node is used."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceFileEntryResponse(ResponseModel):
|
||||
name: str
|
||||
type: Literal["file", "dir", "symlink"]
|
||||
size: int
|
||||
mtime: int
|
||||
|
||||
|
||||
class WorkspaceListResponse(ResponseModel):
|
||||
path: str
|
||||
entries: list[WorkspaceFileEntryResponse] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
|
||||
|
||||
class WorkspacePreviewResponse(ResponseModel):
|
||||
path: str
|
||||
size: int
|
||||
truncated: bool
|
||||
binary: bool
|
||||
text: str | None = None
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, WorkspaceListResponse)
|
||||
register_response_schema_models(console_ns, WorkspacePreviewResponse)
|
||||
|
||||
|
||||
def _handle(exc: Exception) -> tuple[dict[str, object], int]:
|
||||
if isinstance(exc, AgentWorkspaceInspectorError):
|
||||
return {"code": exc.code, "message": exc.message}, exc.status_code
|
||||
if isinstance(exc, AgentBackendHTTPError):
|
||||
detail = exc.detail
|
||||
if isinstance(detail, dict):
|
||||
return {
|
||||
"code": detail.get("code", "agent_backend_error"),
|
||||
"message": detail.get("message", str(exc)),
|
||||
}, exc.status_code
|
||||
return {"code": "agent_backend_error", "message": str(detail)}, exc.status_code
|
||||
if isinstance(exc, AgentBackendTransportError):
|
||||
return {"code": "agent_backend_unreachable", "message": str(exc)}, 502
|
||||
raise exc
|
||||
|
||||
|
||||
def _download_response(result: WorkspaceDownloadResult) -> Response | tuple[dict[str, object], int]:
|
||||
if result.truncated:
|
||||
return {
|
||||
"code": "workspace_file_too_large",
|
||||
"message": (
|
||||
"file exceeds the workspace download limit; use preview for partial text or download a smaller file"
|
||||
),
|
||||
"size": result.size,
|
||||
}, 413
|
||||
filename = result.path.rsplit("/", 1)[-1] or "download"
|
||||
return Response(
|
||||
result.content,
|
||||
mimetype="application/octet-stream",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(len(result.content)),
|
||||
"X-Workspace-File-Size": str(result.size),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files")
|
||||
class AgentAppWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_agent_app_workspace_files")
|
||||
@console_ns.doc(description="List a directory in an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceListQuery)})
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceListQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/preview")
|
||||
class AgentAppWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in an Agent App conversation's sandbox workspace")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent-workspace/files/download")
|
||||
class AgentAppWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_agent_app_workspace_file")
|
||||
@console_ns.doc(description="Download a file from an Agent App conversation's sandbox workspace (read-only)")
|
||||
@console_ns.doc(params={"app_id": "Application ID", **query_params_from_model(AgentWorkspaceFileQuery)})
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT])
|
||||
def get(self, app_model: App):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(AgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = AgentAppWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
conversation_id=query.conversation_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files"
|
||||
)
|
||||
class WorkflowAgentWorkspaceListResource(Resource):
|
||||
@console_ns.doc("list_workflow_agent_workspace_files")
|
||||
@console_ns.doc(description="List a directory in a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceListQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Listing returned", console_ns.models[WorkspaceListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceListQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/preview"
|
||||
)
|
||||
class WorkflowAgentWorkspacePreviewResource(Resource):
|
||||
@console_ns.doc("preview_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Preview a text/binary file in a Workflow Agent node's sandbox workspace")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Preview returned", console_ns.models[WorkspacePreviewResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().preview(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return result.model_dump()
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflow-runs/<uuid:workflow_run_id>/agent-nodes/<string:node_id>/workspace/files/download"
|
||||
)
|
||||
class WorkflowAgentWorkspaceDownloadResource(Resource):
|
||||
@console_ns.doc("download_workflow_agent_workspace_file")
|
||||
@console_ns.doc(description="Download a file from a Workflow Agent node's sandbox workspace (read-only)")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"workflow_run_id": "Workflow run ID",
|
||||
"node_id": "Workflow Agent node ID",
|
||||
**query_params_from_model(WorkflowAgentWorkspaceFileQuery),
|
||||
}
|
||||
)
|
||||
@console_ns.doc(produces=["application/octet-stream"])
|
||||
@console_ns.response(200, "File bytes", _WorkspaceFileDownloadField)
|
||||
@console_ns.response(413, "File exceeds the workspace download limit")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, workflow_run_id: UUID, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
query = query_params_from_request(WorkflowAgentWorkspaceFileQuery)
|
||||
try:
|
||||
result = WorkflowAgentWorkspaceService().download(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
node_id=node_id,
|
||||
node_execution_id=query.node_execution_id,
|
||||
path=query.path,
|
||||
)
|
||||
except Exception as exc: # normalized to an HTTP response below
|
||||
return _handle(exc)
|
||||
return _download_response(result)
|
||||
@ -25,6 +25,9 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
@ -34,8 +37,8 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from libs.helper import build_icon_url, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from libs.login import login_required
|
||||
from models import Account, App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppListParams, AppService, CreateAppParams
|
||||
@ -55,7 +58,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@ -66,7 +69,7 @@ _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@ -115,7 +118,9 @@ def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str,
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "agent", "advanced-chat", "workflow", "completion"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
@ -393,6 +398,8 @@ class AppDetailWithSite(AppDetail):
|
||||
max_active_requests: int | None = None
|
||||
deleted_tools: list[DeletedTool] = Field(default_factory=list)
|
||||
site: Site | None = None
|
||||
# For Agent App type: the roster Agent backing this app (None otherwise).
|
||||
bound_agent_id: str | None = None
|
||||
|
||||
@computed_field(return_type=str | None) # type: ignore
|
||||
@property
|
||||
@ -467,10 +474,11 @@ class AppListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
@with_session(write=False)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user_id: str, session: Session):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
params = AppListParams(
|
||||
page=args.page,
|
||||
@ -483,7 +491,7 @@ class AppListApi(Resource):
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, params)
|
||||
app_pagination = app_service.get_paginate_apps(current_user_id, current_tenant_id, params)
|
||||
if not app_pagination:
|
||||
empty = AppPagination(page=args.page, limit=args.limit, total=0, has_more=False, data=[])
|
||||
return empty.model_dump(mode="json"), 200
|
||||
@ -504,7 +512,7 @@ class AppListApi(Resource):
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
@ -543,9 +551,10 @@ class AppListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
params = CreateAppParams(
|
||||
name=args.name,
|
||||
@ -573,7 +582,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
|
||||
@ -581,7 +590,7 @@ class AppApi(Resource):
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
@ -598,7 +607,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
def put(self, app_model: App):
|
||||
"""Update app"""
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
@ -627,7 +636,7 @@ class AppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model):
|
||||
def delete(self, app_model: App):
|
||||
"""Delete app"""
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
@ -648,11 +657,10 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -709,7 +717,7 @@ class AppExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -731,7 +739,8 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -739,13 +748,11 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
redirect_url = get_redirect_url(current_user_id, claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
@ -762,7 +769,7 @@ class AppNameApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -784,7 +791,7 @@ class AppIconApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
@ -811,7 +818,7 @@ class AppSiteStatus(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -833,7 +840,7 @@ class AppApiStatus(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -874,7 +881,7 @@ class AppTraceApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
|
||||
@ -9,9 +9,11 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, Import
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -48,9 +50,9 @@ class AppImportApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
@ -97,10 +99,9 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
|
||||
@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
@ -171,7 +171,7 @@ class TextModesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
@ -19,7 +19,12 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -33,7 +38,7 @@ from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -41,9 +46,24 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_debugger_chat_streaming(
|
||||
*, app_mode: AppMode, response_mode: str, response_mode_provided: bool = True
|
||||
) -> bool:
|
||||
"""Agent App runtime is SSE-only until backend blocking runs are supported."""
|
||||
if app_mode != AppMode.AGENT:
|
||||
return response_mode != "blocking"
|
||||
if response_mode_provided and response_mode == "blocking":
|
||||
raise BadRequest("Agent App only supports streaming response mode.")
|
||||
return True
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
# Agent Apps (AppMode.AGENT) derive their model + prompt from the bound Agent
|
||||
# Soul, so no override ``model_config`` is sent; chat / agent-chat / completion
|
||||
# debugging still pass it. Optional here, required in practice by those modes
|
||||
# downstream when their config is built from args.
|
||||
model_config_data: dict[str, Any] = Field(default_factory=dict, alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
@ -84,7 +104,7 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
def post(self, app_model: App):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -131,14 +151,13 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -157,13 +176,20 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
def post(self, app_model: App):
|
||||
raw_payload = console_ns.payload or {}
|
||||
args_model = ChatMessagePayload.model_validate(raw_payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
streaming = _resolve_debugger_chat_streaming(
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
response_mode=args_model.response_mode,
|
||||
response_mode_provided=isinstance(raw_payload, dict) and "response_mode" in raw_payload,
|
||||
)
|
||||
if AppMode.value_of(app_model.mode) == AppMode.AGENT:
|
||||
args["response_mode"] = "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -211,15 +237,14 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
|
||||
@ -12,7 +12,12 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import (
|
||||
@ -31,9 +36,10 @@ from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from models.account import Account
|
||||
from models.model import App, AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
@ -93,8 +99,8 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
@ -165,10 +171,11 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,8 +189,8 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -205,10 +212,10 @@ class ChatConversationApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
subquery = (
|
||||
@ -316,12 +323,13 @@ class ChatConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(current_user, app_model, conversation_id_str), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -332,11 +340,11 @@ class ChatConversationDetailApi(Resource):
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, app_model: App, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
try:
|
||||
@ -347,8 +355,7 @@ class ChatConversationDetailApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def _get_conversation(current_user: Account, app_model, conversation_id):
|
||||
conversation = db.session.scalar(
|
||||
sa.select(Conversation).where(Conversation.id == conversation_id, Conversation.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
@ -94,7 +94,7 @@ class ConversationVariablesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
stmt = (
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_enum_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -11,7 +12,8 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.app.wraps import with_session
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -19,10 +21,9 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
@ -64,9 +65,9 @@ class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
@ -93,9 +94,9 @@ class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
@ -125,9 +126,9 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
@ -157,9 +158,10 @@ class InstructionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
@with_session(write=False)
|
||||
def post(self, session: Session, current_tenant_id: str):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
@ -168,10 +170,10 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.get(App, args.flow_id)
|
||||
app = session.get(App, args.flow_id)
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app, session=session)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
|
||||
@ -11,13 +11,18 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import AppMCPServer
|
||||
from models.model import App, AppMCPServer
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
@ -73,7 +78,7 @@ class AppMCPServerController(Resource):
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
if server is None:
|
||||
return {}
|
||||
@ -92,8 +97,8 @@ class AppMCPServerController(Resource):
|
||||
@login_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = payload.description
|
||||
@ -127,7 +132,7 @@ class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def put(self, app_model):
|
||||
def put(self, app_model: App):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
@ -163,8 +168,8 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, server_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, server_id: UUID):
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
|
||||
@ -25,6 +25,7 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
@ -43,9 +44,10 @@ from fields.conversation_fields import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
@ -178,9 +180,9 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = db.session.scalar(
|
||||
@ -257,9 +259,8 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
|
||||
message_id = str(args.message_id)
|
||||
@ -314,7 +315,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
@ -336,9 +337,9 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
try:
|
||||
@ -379,7 +380,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
def get(self, app_model: App):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
@ -417,7 +418,7 @@ class MessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model, message_id: UUID):
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
|
||||
@ -8,15 +8,21 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@ -52,9 +58,10 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user_id: str, app_model: App):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -64,8 +71,8 @@ class ModelConfigResource(Resource):
|
||||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
created_by=current_user_id,
|
||||
updated_by=current_user_id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
@ -90,7 +97,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -130,7 +137,7 @@ class ModelConfigResource(Resource):
|
||||
tenant_id=current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
@ -167,7 +174,7 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_by = current_user_id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -14,12 +14,15 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Site
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
@ -84,9 +87,9 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -133,8 +136,8 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
if not site:
|
||||
|
||||
@ -8,13 +8,15 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import AppMode
|
||||
from models.account import Account
|
||||
from models.model import App
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
@ -47,9 +49,8 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -61,8 +62,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -104,9 +109,8 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -118,8 +122,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -160,9 +168,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -174,8 +181,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -217,9 +228,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -232,8 +242,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -276,10 +290,9 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
@ -299,8 +312,12 @@ FROM
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -353,9 +370,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
@ -371,8 +387,12 @@ LEFT JOIN
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -419,9 +439,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -433,8 +452,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -476,8 +499,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
@ -492,8 +515,12 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Concatenate, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -83,13 +83,14 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
match value:
|
||||
case FileSegment():
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
case ArrayFileSegment():
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
@ -213,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -230,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -6,12 +6,13 @@ from sqlalchemy.orm import sessionmaker
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@ -46,9 +47,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -86,9 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -126,9 +125,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
@ -166,9 +164,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account, app_model: App):
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
assert account.timezone is not None
|
||||
|
||||
@ -32,11 +32,11 @@ from controllers.console.wraps import (
|
||||
decrypt_password_field,
|
||||
email_password_login_enabled,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
@ -46,6 +46,7 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
@ -172,9 +173,8 @@ class LoginApi(Resource):
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
|
||||
@ -8,9 +8,16 @@ from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -32,8 +39,9 @@ class Subscription(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
@ -45,8 +53,9 @@ class Invoices(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
@ -63,9 +72,8 @@ class PartnerTenants(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, partner_key: str):
|
||||
try:
|
||||
args = PartnerTenantsPayload.model_validate(console_ns.payload or {})
|
||||
click_id = args.click_id
|
||||
|
||||
@ -3,11 +3,18 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from ..wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ComplianceDownloadQuery(BaseModel):
|
||||
@ -29,8 +36,9 @@ class ComplianceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
|
||||
@ -1,41 +1,37 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_serializer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse, TextContentResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_model
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.entities.knowledge_entities import IndexingEstimate
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import (
|
||||
integrate_fields,
|
||||
integrate_icon_fields,
|
||||
integrate_list_fields,
|
||||
integrate_notion_info_list_fields,
|
||||
integrate_page_fields,
|
||||
integrate_workspace_fields,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
@ -54,50 +50,74 @@ class DataSourceNotionPreviewQuery(BaseModel):
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
|
||||
|
||||
register_schema_model(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse, TextContentResponse)
|
||||
class DataSourceIntegrateIconResponse(ResponseModel):
|
||||
type: str | None = None
|
||||
url: str | None = None
|
||||
emoji: str | None = None
|
||||
|
||||
|
||||
integrate_icon_model = get_or_create_model("DataSourceIntegrateIcon", integrate_icon_fields)
|
||||
class DataSourceIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str
|
||||
type: str
|
||||
|
||||
integrate_page_fields_copy = integrate_page_fields.copy()
|
||||
integrate_page_fields_copy["page_icon"] = fields.Nested(integrate_icon_model, allow_null=True)
|
||||
integrate_page_model = get_or_create_model("DataSourceIntegratePage", integrate_page_fields_copy)
|
||||
|
||||
integrate_workspace_fields_copy = integrate_workspace_fields.copy()
|
||||
integrate_workspace_fields_copy["pages"] = fields.List(fields.Nested(integrate_page_model))
|
||||
integrate_workspace_model = get_or_create_model("DataSourceIntegrateWorkspace", integrate_workspace_fields_copy)
|
||||
class DataSourceIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[DataSourceIntegratePageResponse]
|
||||
total: int
|
||||
|
||||
integrate_fields_copy = integrate_fields.copy()
|
||||
integrate_fields_copy["source_info"] = fields.Nested(integrate_workspace_model)
|
||||
integrate_model = get_or_create_model("DataSourceIntegrate", integrate_fields_copy)
|
||||
|
||||
integrate_list_fields_copy = integrate_list_fields.copy()
|
||||
integrate_list_fields_copy["data"] = fields.List(fields.Nested(integrate_model))
|
||||
integrate_list_model = get_or_create_model("DataSourceIntegrateList", integrate_list_fields_copy)
|
||||
class DataSourceIntegrateResponse(ResponseModel):
|
||||
id: str | None
|
||||
provider: str
|
||||
created_at: datetime | int | None
|
||||
is_bound: bool
|
||||
disabled: bool | None
|
||||
link: str
|
||||
source_info: DataSourceIntegrateWorkspaceResponse | None
|
||||
|
||||
notion_page_fields = {
|
||||
"page_name": fields.String,
|
||||
"page_id": fields.String,
|
||||
"page_icon": fields.Nested(integrate_icon_model, allow_null=True),
|
||||
"is_bound": fields.Boolean,
|
||||
"parent_id": fields.String,
|
||||
"type": fields.String,
|
||||
}
|
||||
notion_page_model = get_or_create_model("NotionIntegratePage", notion_page_fields)
|
||||
@field_serializer("created_at")
|
||||
def serialize_created_at(self, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
notion_workspace_fields = {
|
||||
"workspace_name": fields.String,
|
||||
"workspace_id": fields.String,
|
||||
"workspace_icon": fields.String,
|
||||
"pages": fields.List(fields.Nested(notion_page_model)),
|
||||
}
|
||||
notion_workspace_model = get_or_create_model("NotionIntegrateWorkspace", notion_workspace_fields)
|
||||
|
||||
integrate_notion_info_list_fields_copy = integrate_notion_info_list_fields.copy()
|
||||
integrate_notion_info_list_fields_copy["notion_info"] = fields.List(fields.Nested(notion_workspace_model))
|
||||
integrate_notion_info_list_model = get_or_create_model(
|
||||
"NotionIntegrateInfoList", integrate_notion_info_list_fields_copy
|
||||
class DataSourceIntegrateListResponse(ResponseModel):
|
||||
data: list[DataSourceIntegrateResponse]
|
||||
|
||||
|
||||
class NotionIntegratePageResponse(ResponseModel):
|
||||
page_name: str
|
||||
page_id: str
|
||||
page_icon: DataSourceIntegrateIconResponse | None
|
||||
parent_id: str | None
|
||||
type: str
|
||||
is_bound: bool
|
||||
|
||||
|
||||
class NotionIntegrateWorkspaceResponse(ResponseModel):
|
||||
workspace_name: str | None
|
||||
workspace_id: str | None
|
||||
workspace_icon: str | None
|
||||
pages: list[NotionIntegratePageResponse]
|
||||
|
||||
|
||||
class NotionIntegrateInfoListResponse(ResponseModel):
|
||||
notion_info: list[NotionIntegrateWorkspaceResponse]
|
||||
|
||||
|
||||
register_schema_models(console_ns, NotionEstimatePayload)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DataSourceIntegrateListResponse,
|
||||
IndexingEstimate,
|
||||
NotionIntegrateInfoListResponse,
|
||||
SimpleResultResponse,
|
||||
TextContentResponse,
|
||||
)
|
||||
|
||||
|
||||
@ -109,10 +129,9 @@ class DataSourceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_model)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -154,19 +173,21 @@ class DataSourceApi(Resource):
|
||||
"link": f"{base_url}{data_source_oauth_base_path}/{provider}",
|
||||
}
|
||||
)
|
||||
return {"data": integrate_data}, 200
|
||||
return dump_response(DataSourceIntegrateListResponse, {"data": integrate_data}), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
|
||||
) -> tuple[dict[str, str], int]:
|
||||
binding_id_str = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
DataSourceOauthBinding.id == binding_id_str, DataSourceOauthBinding.tenant_id == current_tenant_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
@ -198,12 +219,12 @@ class DataSourceNotionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_model)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -278,22 +299,22 @@ class DataSourceNotionListApi(Resource):
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
|
||||
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/notion/pages/<uuid:page_id>/<string:page_type>/preview",
|
||||
"/datasets/notion-indexing-estimate",
|
||||
)
|
||||
class DataSourceNotionApi(Resource):
|
||||
@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
|
||||
class DataSourceNotionPreviewApi(Resource):
|
||||
"""Preview one authorized Notion page through the datasource credential."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
@ -316,13 +337,18 @@ class DataSourceNotionApi(Resource):
|
||||
text_docs = extractor.extract()
|
||||
return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/notion-indexing-estimate")
|
||||
class DataSourceNotionIndexingEstimateApi(Resource):
|
||||
"""Estimate indexing work for selected Notion pages."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
# validate args
|
||||
@ -355,7 +381,7 @@ class DataSourceNotionApi(Resource):
|
||||
args["doc_form"],
|
||||
args["doc_language"],
|
||||
)
|
||||
return response.model_dump(), 200
|
||||
return dump_response(IndexingEstimate, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/notion/sync")
|
||||
@ -364,7 +390,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -382,7 +408,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -9,7 +9,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import asc, desc, func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -34,16 +34,18 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.document_fields import (
|
||||
document_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentStatusListResponse,
|
||||
DocumentStatusResponse,
|
||||
normalize_enum,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
@ -69,17 +71,13 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_enum(value: Any) -> Any:
|
||||
if isinstance(value, str) or value is None:
|
||||
return value
|
||||
return getattr(value, "value", value)
|
||||
|
||||
|
||||
class DatasetResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
@ -93,7 +91,7 @@ class DatasetResponse(ResponseModel):
|
||||
@field_validator("data_source_type", "indexing_technique", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
return normalize_enum(value)
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
@ -101,61 +99,10 @@ class DatasetResponse(ResponseModel):
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentMetadataResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class DocumentResponse(ResponseModel):
|
||||
id: str
|
||||
position: int | None = None
|
||||
data_source_type: str | None = None
|
||||
data_source_info: Any = Field(default=None, validation_alias="data_source_info_dict")
|
||||
data_source_detail_dict: Any = None
|
||||
dataset_process_rule_id: str | None = None
|
||||
name: str
|
||||
created_from: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
tokens: int | None = None
|
||||
indexing_status: str | None = None
|
||||
error: str | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
archived: bool | None = None
|
||||
display_status: str | None = None
|
||||
word_count: int | None = None
|
||||
hit_count: int | None = None
|
||||
doc_form: str | None = None
|
||||
doc_metadata: list[DocumentMetadataResponse] = Field(default_factory=list, validation_alias="doc_metadata_details")
|
||||
summary_index_status: str | None = None
|
||||
need_summary: bool | None = None
|
||||
|
||||
@field_validator("data_source_type", "indexing_status", "display_status", "doc_form", mode="before")
|
||||
@classmethod
|
||||
def _normalize_enum_fields(cls, value: Any) -> Any:
|
||||
return _normalize_enum(value)
|
||||
|
||||
@field_validator("doc_metadata", mode="before")
|
||||
@classmethod
|
||||
def _normalize_doc_metadata(cls, value: Any) -> list[Any]:
|
||||
if value is None:
|
||||
return []
|
||||
return value
|
||||
|
||||
@field_validator("created_at", "disabled_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentWithSegmentsResponse(DocumentResponse):
|
||||
process_rule_dict: Any = None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
completed_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
total_segments: int | None = Field(default=None, exclude_if=lambda value: value is None)
|
||||
|
||||
|
||||
class DatasetAndDocumentResponse(ResponseModel):
|
||||
@ -190,6 +137,14 @@ class DocumentDatasetListParam(BaseModel):
|
||||
fetch_val: str = Field("false", alias="fetch")
|
||||
|
||||
|
||||
class DocumentWithSegmentsListResponse(ResponseModel):
|
||||
data: list[DocumentWithSegmentsResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
KnowledgeConfig,
|
||||
@ -200,18 +155,25 @@ register_schema_models(
|
||||
GenerateSummaryPayload,
|
||||
DocumentMetadataUpdatePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SimpleResultMessageResponse,
|
||||
SimpleResultResponse,
|
||||
UrlResponse,
|
||||
DatasetResponse,
|
||||
DocumentMetadataResponse,
|
||||
DocumentResponse,
|
||||
DocumentWithSegmentsResponse,
|
||||
DatasetAndDocumentResponse,
|
||||
DocumentWithSegmentsListResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultMessageResponse, SimpleResultResponse, UrlResponse)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
def get_document(
|
||||
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
|
||||
) -> Document:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -231,8 +193,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -259,8 +220,8 @@ class GetProcessRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -312,12 +273,17 @@ class DatasetDocumentListApi(Resource):
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Documents retrieved successfully",
|
||||
console_ns.models[DocumentWithSegmentsListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
@ -425,18 +391,15 @@ class DatasetDocumentListApi(Resource):
|
||||
)
|
||||
document.completed_segments = completed_segments
|
||||
document.total_segments = total_segments
|
||||
data = marshal(documents, document_with_segments_fields)
|
||||
else:
|
||||
data = marshal(documents, document_fields)
|
||||
response = {
|
||||
"data": data,
|
||||
"data": documents,
|
||||
"has_more": len(documents) == limit,
|
||||
"limit": limit,
|
||||
"total": paginated_documents.total,
|
||||
"page": page,
|
||||
}
|
||||
|
||||
return response
|
||||
return dump_response(DocumentWithSegmentsListResponse, response)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -445,8 +408,8 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -482,9 +445,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -522,9 +483,10 @@ class DatasetInitApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -567,9 +529,7 @@ class DatasetInitApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return DatasetAndDocumentResponse.model_validate(
|
||||
{"dataset": dataset, "documents": documents, "batch": batch}, from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
return dump_response(DatasetAndDocumentResponse, {"dataset": dataset, "documents": documents, "batch": batch})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
@ -583,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -648,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -742,12 +704,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/batch/<string:batch>/indexing-status")
|
||||
class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusListResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -784,9 +750,8 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status})
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
@ -794,21 +759,25 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200, "Indexing status retrieved successfully", console_ns.models[DocumentStatusResponse.__name__]
|
||||
)
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -817,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -839,7 +808,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
return marshal(document_dict, document_status_fields)
|
||||
return dump_response(DocumentStatusResponse, document_dict)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
|
||||
@ -860,10 +829,12 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -949,7 +920,9 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -958,7 +931,7 @@ class DocumentApi(DocumentResource):
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -979,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@ -996,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
@ -1043,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["pause", "resume"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1091,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1140,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(
|
||||
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -1256,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in payload.document_ids:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
@ -1288,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1304,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return DocumentResponse.model_validate(document, from_attributes=True).model_dump(mode="json")
|
||||
return dump_response(DocumentResponse, document)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
@ -1313,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -1391,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1399,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
@ -1484,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1497,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from typing import cast as type_cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -13,7 +14,12 @@ import services
|
||||
from configs import dify_config
|
||||
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
query_params_from_request,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from controllers.console.datasets.error import (
|
||||
@ -27,6 +33,8 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
@ -34,30 +42,29 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.base import ResponseModel
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from fields.segment_fields import (
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkResponse,
|
||||
SegmentDetailResponse,
|
||||
SegmentResponse,
|
||||
segment_response_with_summary,
|
||||
segment_responses_with_summaries,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -67,6 +74,16 @@ class SegmentListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentIdListQuery(BaseModel):
|
||||
segment_id: list[str] = Field(default_factory=list, description="Segment IDs")
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
content: str
|
||||
answer: str | None = None
|
||||
@ -92,13 +109,35 @@ class SegmentBatchImportStatusResponse(ResponseModel):
|
||||
job_status: str
|
||||
|
||||
|
||||
class ConsoleSegmentListResponse(ResponseModel):
|
||||
data: list[SegmentResponse]
|
||||
limit: int
|
||||
total: int
|
||||
total_pages: int
|
||||
page: int
|
||||
|
||||
|
||||
class ChildChunkBatchUpdateResponse(ResponseModel):
|
||||
data: list[ChildChunkResponse]
|
||||
|
||||
|
||||
class ChildChunkBatchUpdatePayload(BaseModel):
|
||||
chunks: list[ChildChunkUpdateArgs]
|
||||
|
||||
|
||||
class SegmentDocParams:
|
||||
DATASET_DOCUMENT = {"dataset_id": "Dataset ID", "document_id": "Document ID"}
|
||||
DATASET_DOCUMENT_ACTION = {**DATASET_DOCUMENT, "action": "Action"}
|
||||
DATASET_DOCUMENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Segment ID"}
|
||||
DATASET_DOCUMENT_PARENT_SEGMENT = {**DATASET_DOCUMENT, "segment_id": "Parent segment ID"}
|
||||
DATASET_DOCUMENT_CHILD_CHUNK = {**DATASET_DOCUMENT_PARENT_SEGMENT, "child_chunk_id": "Child chunk ID"}
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
SegmentListQuery,
|
||||
SegmentIdListQuery,
|
||||
ChildChunkListQuery,
|
||||
SegmentCreatePayload,
|
||||
SegmentUpdatePayload,
|
||||
BatchImportPayload,
|
||||
@ -107,17 +146,30 @@ register_schema_models(
|
||||
ChildChunkBatchUpdatePayload,
|
||||
ChildChunkUpdateArgs,
|
||||
)
|
||||
register_response_schema_models(console_ns, SegmentBatchImportStatusResponse, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
SegmentResponse,
|
||||
ConsoleSegmentListResponse,
|
||||
SegmentDetailResponse,
|
||||
ChildChunkDetailResponse,
|
||||
ChildChunkListResponse,
|
||||
ChildChunkBatchUpdateResponse,
|
||||
SegmentBatchImportStatusResponse,
|
||||
SimpleResultResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
|
||||
class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentListQuery))
|
||||
@console_ns.response(200, "Segments retrieved successfully", console_ns.models[ConsoleSegmentListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -134,12 +186,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
**request.args.to_dict(),
|
||||
"status": request.args.getlist("status"),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(SegmentListQuery, list_fields=("status",))
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -169,9 +216,17 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
@ -197,42 +252,33 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
segment_list = list(segments.items)
|
||||
segment_ids = [segment.id for segment in segment_list]
|
||||
summaries: dict[str, str | None] = {}
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": segments_with_summary,
|
||||
"data": segment_responses_with_summaries(segment_list, summaries),
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(ConsoleSegmentListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -260,15 +306,24 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
class DatasetDocumentSegmentApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_ACTION)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["enable", "disable"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -313,11 +368,12 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
SegmentService.update_segments_status(segment_ids, action, dataset, document)
|
||||
except Exception as e:
|
||||
raise InvalidActionError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
|
||||
class DatasetDocumentSegmentAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -325,9 +381,10 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -364,21 +421,30 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload = SegmentCreatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset))
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -432,16 +498,24 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str)
|
||||
response = {
|
||||
"data": segment_response_with_summary(segment, summary.summary_content if summary else None),
|
||||
"doc_form": document.doc_form,
|
||||
}
|
||||
return dump_response(SegmentDetailResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -487,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -515,11 +589,11 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
try:
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = f"segment_batch_import_{str(job_id)}"
|
||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id),
|
||||
job_id,
|
||||
upload_file_id,
|
||||
dataset_id_str,
|
||||
document_id_str,
|
||||
@ -528,7 +602,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
return {"job_id": job_id, "job_status": "waiting"}, 200
|
||||
return dump_response(SegmentBatchImportStatusResponse, {"job_id": job_id, "job_status": "waiting"}), 200
|
||||
|
||||
@console_ns.response(200, "Batch import status", console_ns.models[SegmentBatchImportStatusResponse.__name__])
|
||||
@setup_required
|
||||
@ -543,11 +617,13 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
|
||||
return {"job_id": job_id, "job_status": cache_result.decode()}, 200
|
||||
response = {"job_id": job_id, "job_status": cache_result.decode()}
|
||||
return dump_response(SegmentBatchImportStatusResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks")
|
||||
class ChildChunkAddApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -555,9 +631,12 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -605,14 +684,16 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.doc(params=query_params_from_model(ChildChunkListQuery))
|
||||
@console_ns.response(200, "Child chunks retrieved successfully", console_ns.models[ChildChunkListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -634,13 +715,7 @@ class ChildChunkAddApi(Resource):
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
args = query_params_from_request(ChildChunkListQuery, use_defaults_for_malformed_ints=True)
|
||||
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
@ -649,22 +724,32 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
response = {
|
||||
"data": child_chunks.items,
|
||||
"total": child_chunks.total,
|
||||
"total_pages": child_chunks.pages,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
}, 200
|
||||
}
|
||||
return dump_response(ChildChunkListResponse, response), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_PARENT_SEGMENT)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Child chunks updated successfully",
|
||||
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -699,7 +784,7 @@ class ChildChunkAddApi(Resource):
|
||||
child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -710,10 +795,19 @@ class ChildChunkUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -740,7 +834,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -767,10 +861,20 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -797,7 +901,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == child_chunk_id_str,
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
@ -819,4 +923,4 @@ class ChildChunkUpdateApi(Resource):
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -29,7 +30,8 @@ from fields.dataset_fields import (
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -20,86 +17,8 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
register_schema_models(console_ns, HitTestingPayload)
|
||||
register_response_schema_models(console_ns, HitTestingResponse)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -119,12 +38,11 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: UUID) -> dict[str, object]:
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args = self.parse_args(console_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -19,10 +18,10 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -38,16 +37,6 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
@ -63,6 +52,7 @@ class DatasetsHitTestingBase:
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
normalized_segment.setdefault("sign_content", None)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
@ -73,12 +63,15 @@ class DatasetsHitTestingBase:
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_record.setdefault("tsne_position", None)
|
||||
normalized_record.setdefault("summary", None)
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
def get_and_validate_dataset(dataset_id: str) -> Dataset:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -92,33 +85,35 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
def hit_testing_args_check(args: dict[str, Any]) -> None:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Validate and return hit-testing arguments from an incoming payload."""
|
||||
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args.get("query"),
|
||||
query=cast(str, args.get("query")),
|
||||
account=current_user,
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
query = response.get("query")
|
||||
if not isinstance(query, dict) or not isinstance(query.get("content"), str):
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
"query": {"content": query["content"]},
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
|
||||
@ -7,14 +7,20 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -43,8 +49,8 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -80,8 +86,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
|
||||
@ -100,8 +106,8 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -137,8 +143,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -165,8 +171,8 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -9,11 +9,18 @@ from configs import dify_config
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -195,15 +200,17 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, user: Account, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
user=user,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
@ -216,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -241,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -264,9 +269,8 @@ class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -277,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -292,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -310,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
@ -330,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -352,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@ -1,13 +1,20 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -16,79 +23,132 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineTemplateListQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
language: str = Field(default="en-US", description="Template language")
|
||||
|
||||
|
||||
class PipelineTemplateDetailQuery(BaseModel):
|
||||
type: str = Field(default="built-in", description="Template source: built-in or customized")
|
||||
|
||||
|
||||
class PipelineTemplateItemResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon: dict[str, Any]
|
||||
description: str
|
||||
position: int
|
||||
chunk_structure: str
|
||||
copyright: str | None = None
|
||||
privacy_policy: str | None = None
|
||||
|
||||
|
||||
class PipelineTemplateListResponse(ResponseModel):
|
||||
pipeline_templates: list[PipelineTemplateItemResponse]
|
||||
|
||||
|
||||
class PipelineTemplateDetailResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
icon_info: dict[str, Any]
|
||||
description: str
|
||||
chunk_structure: str
|
||||
export_data: str
|
||||
graph: dict[str, Any]
|
||||
created_by: str | None = None
|
||||
|
||||
|
||||
class CustomizedPipelineTemplatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] = Field(default_factory=lambda: IconInfo(icon="").model_dump())
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
CustomizedPipelineTemplatePayload,
|
||||
PipelineTemplateDetailQuery,
|
||||
PipelineTemplateListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
PipelineTemplateDetailResponse,
|
||||
PipelineTemplateListResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateListQuery))
|
||||
@console_ns.response(200, "Pipeline templates", console_ns.models[PipelineTemplateListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
def get(self) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language)
|
||||
return dump_response(PipelineTemplateListResponse, pipeline_templates), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates/<string:template_id>")
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(PipelineTemplateDetailQuery))
|
||||
@console_ns.response(200, "Pipeline template", console_ns.models[PipelineTemplateDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
def get(self, template_id: str) -> JsonResponseWithStatus:
|
||||
query = PipelineTemplateDetailQuery.model_validate(request.args.to_dict(flat=True))
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, query.type)
|
||||
if pipeline_template is None:
|
||||
return {"error": "Pipeline template not found from upstream service."}, 404
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
register_response_schema_models(console_ns, SimpleDataResponse)
|
||||
raise NotFound("Pipeline template not found from upstream service.")
|
||||
return dump_response(PipelineTemplateDetailResponse, pipeline_template), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template updated")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def patch(self, template_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(204, "Pipeline template deleted")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
def delete(self, template_id: str) -> tuple[str, int]:
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleDataResponse.__name__])
|
||||
def post(self, template_id: str):
|
||||
def post(self, template_id: str) -> JsonResponseWithStatus:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||
@ -96,19 +156,20 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": template.yaml_content}), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@console_ns.expect(console_ns.models[CustomizedPipelineTemplatePayload.__name__])
|
||||
@console_ns.response(204, "Pipeline template published")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
def post(self, pipeline_id: str) -> tuple[str, int]:
|
||||
payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
return {"result": "success"}
|
||||
return "", 204
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.common.schema import JsonResponseWithStatus, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import RagPipelineImportResponse
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -25,19 +30,26 @@ class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_schema_models(console_ns, RagPipelineDatasetImportPayload)
|
||||
register_response_schema_models(console_ns, DatasetDetailResponse, RagPipelineImportResponse)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@console_ns.response(
|
||||
201,
|
||||
"RAG pipeline dataset import started",
|
||||
console_ns.models[RagPipelineImportResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -70,19 +82,20 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
return dump_response(RagPipelineImportResponse, import_info), 201
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/empty-dataset")
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@console_ns.response(201, "RAG pipeline dataset created", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWithStatus:
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
@ -99,4 +112,4 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -57,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -72,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -1,23 +1,29 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.common.fields import SimpleDataResponse
|
||||
from controllers.common.schema import (
|
||||
JsonResponseWithStatus,
|
||||
query_params_from_model,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import (
|
||||
leaked_dependency_fields,
|
||||
pipeline_import_check_dependencies_fields,
|
||||
pipeline_import_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.entities.dsl_entities import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -36,35 +42,45 @@ class RagPipelineImportPayload(BaseModel):
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
include_secret: str = Field(default="false")
|
||||
include_secret: str = Field(default="false", description="Whether to include secret values in the exported DSL")
|
||||
|
||||
|
||||
class RagPipelineImportResponse(ResponseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
pipeline_id: str | None = None
|
||||
dataset_id: str | None = None
|
||||
current_dsl_version: str
|
||||
imported_dsl_version: str
|
||||
error: str = ""
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesResponse(ResponseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||
|
||||
|
||||
pipeline_import_model = get_or_create_model("RagPipelineImport", pipeline_import_fields)
|
||||
|
||||
leaked_dependency_model = get_or_create_model("RagPipelineLeakedDependency", leaked_dependency_fields)
|
||||
pipeline_import_check_dependencies_fields_copy = pipeline_import_check_dependencies_fields.copy()
|
||||
pipeline_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(
|
||||
fields.Nested(leaked_dependency_model)
|
||||
)
|
||||
pipeline_import_check_dependencies_model = get_or_create_model(
|
||||
"RagPipelineImportCheckDependencies", pipeline_import_check_dependencies_fields_copy
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
RagPipelineImportCheckDependenciesResponse,
|
||||
RagPipelineImportResponse,
|
||||
SimpleDataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
@console_ns.response(200, "Import completed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(202, "Import pending confirmation", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account) -> JsonResponseWithStatus:
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Use a plain Session so that caught exceptions inside the service
|
||||
@ -91,23 +107,23 @@ class RagPipelineImportApi(Resource):
|
||||
status = result.status
|
||||
match status:
|
||||
case ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
case ImportStatus.PENDING:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return dump_response(RagPipelineImportResponse, result), 202
|
||||
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@console_ns.response(200, "Import confirmed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@console_ns.response(400, "Import failed", console_ns.models[RagPipelineImportResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, import_id: str) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
account = current_user
|
||||
@ -119,34 +135,40 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportResponse, result), 400
|
||||
return dump_response(RagPipelineImportResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports/<string:pipeline_id>/check-dependencies")
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dependencies checked",
|
||||
console_ns.models[RagPipelineImportCheckDependenciesResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
return dump_response(RagPipelineImportCheckDependenciesResponse, result), 200
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/exports")
|
||||
class RagPipelineExportApi(Resource):
|
||||
@console_ns.doc(params=query_params_from_model(IncludeSecretQuery))
|
||||
@console_ns.response(200, "Pipeline exported", console_ns.models[SimpleDataResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
def get(self, pipeline: Pipeline) -> JsonResponseWithStatus:
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@ -156,4 +178,4 @@ class RagPipelineExportApi(Resource):
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
return dump_response(SimpleDataResponse, {"data": result}), 200
|
||||
|
||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
@ -12,14 +12,19 @@ from controllers.common.fields import SimpleMessageResponse, SimpleResultMessage
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from libs.login import login_required
|
||||
from models import Account, App, InstalledApp, RecommendedApp
|
||||
from models.model import IconType
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -131,9 +136,10 @@ class InstalledAppsListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[InstalledAppListResponse.__name__])
|
||||
def get(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if query.app_id:
|
||||
installed_apps = db.session.scalars(
|
||||
@ -212,7 +218,8 @@ class InstalledAppsListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleMessageResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
recommended_app = db.session.scalar(
|
||||
@ -221,8 +228,6 @@ class InstalledAppsListApi(Resource):
|
||||
if recommended_app is None:
|
||||
raise NotFound("Recommended app not found")
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
app = db.session.get(App, payload.app_id)
|
||||
|
||||
if app is None:
|
||||
@ -262,8 +267,8 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, installed_app: InstalledApp):
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
|
||||
|
||||
@ -2,13 +2,36 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
from services.feature_service import (
|
||||
FeatureModel,
|
||||
FeatureService,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
)
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
class TrialModelsResponse(ResponseModel):
|
||||
trial_models: list[str]
|
||||
|
||||
|
||||
class AppDslVersionResponse(ResponseModel):
|
||||
app_dsl_version: str
|
||||
|
||||
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AppDslVersionResponse,
|
||||
FeatureModel,
|
||||
LimitationModel,
|
||||
SystemFeatureModel,
|
||||
TrialModelsResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -54,6 +77,43 @@ class FeatureVectorSpaceApi(Resource):
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/trial-models")
|
||||
class TrialModelsApi(Resource):
|
||||
@console_ns.doc("get_trial_models")
|
||||
@console_ns.doc(description="Get hosted trial model provider configuration")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[TrialModelsResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Get hosted trial model provider configuration for model-provider pages."""
|
||||
return dump_response(
|
||||
TrialModelsResponse,
|
||||
{"trial_models": FeatureService.get_trial_models()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app-dsl-version")
|
||||
class AppDslVersionApi(Resource):
|
||||
@console_ns.doc("get_app_dsl_version")
|
||||
@console_ns.doc(description="Get current app DSL version")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[AppDslVersionResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
"""Get current app DSL version for workflow clipboard compatibility."""
|
||||
return dump_response(
|
||||
AppDslVersionResponse,
|
||||
{"app_dsl_version": FeatureService.get_app_dsl_version()},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
class SystemFeatureApi(Resource):
|
||||
@console_ns.doc("get_system_features")
|
||||
|
||||
@ -12,8 +12,15 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
model_validate,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
@ -22,8 +29,8 @@ from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from libs.login import login_required
|
||||
from models import Account, App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
@ -33,6 +40,8 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(console_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
@ -45,9 +54,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
def _ensure_console_access(form: Form, current_tenant_id: str) -> None:
|
||||
"""Ensure a console form token resolves only inside the current tenant."""
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@ -59,7 +67,8 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
@ -70,13 +79,23 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
@model_validate(HumanInputFormSubmitPayload)
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
def post(
|
||||
self,
|
||||
payload: HumanInputFormSubmitPayload,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
form_token: str,
|
||||
):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
@ -90,15 +109,12 @@ class ConsoleHumanInputFormApi(Resource):
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
self._ensure_console_access(form, current_tenant_id)
|
||||
self._ensure_console_recipient_type(form)
|
||||
recipient_type = form.recipient_type
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
@ -122,7 +138,9 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
@ -130,8 +148,6 @@ class ConsoleWorkflowEventsApi(Resource):
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from core.file import remote_fetcher
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
@ -36,9 +36,9 @@ class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
resp = remote_fetcher.make_request("HEAD", decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp = remote_fetcher.make_request("GET", decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
@ -58,9 +58,9 @@ class RemoteFileUpload(Resource):
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
resp = remote_fetcher.make_request("HEAD", url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
resp = remote_fetcher.make_request("GET", url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
@ -74,7 +74,7 @@ class RemoteFileUpload(Resource):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
content = resp.content if resp.request.method == "GET" else remote_fetcher.make_request("GET", url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
|
||||
@ -18,7 +18,7 @@ from controllers.common.fields import (
|
||||
SimpleResultResponse,
|
||||
VerificationTokenResponse,
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
@ -42,15 +42,17 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -173,7 +175,6 @@ class CheckEmailUniquePayload(BaseModel):
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AccountInitPayload,
|
||||
AccountNamePayload,
|
||||
AccountAvatarPayload,
|
||||
@ -245,6 +246,7 @@ register_schema_models(
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
AccountResponse,
|
||||
AvatarUrlResponse,
|
||||
SimpleResultDataResponse,
|
||||
SimpleResultResponse,
|
||||
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
@ -329,20 +330,21 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@console_ns.doc(params=query_params_from_model(AccountAvatarQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[AvatarUrlResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar})
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
@ -355,15 +357,15 @@ class AccountAvatarApi(Resource):
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
return dump_response(AvatarUrlResponse, {"avatar_url": avatar_url})
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
account_integrates = db.session.scalars(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||
).all()
|
||||
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_deletion_verification_email(account, code)
|
||||
|
||||
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletePayload.model_validate(payload)
|
||||
|
||||
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
@ -563,9 +561,8 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = EducationActivatePayload.model_validate(payload)
|
||||
|
||||
@ -577,9 +574,8 @@ class EducationApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailSendPayload.model_validate(payload)
|
||||
|
||||
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -19,14 +25,10 @@ class AgentProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
return jsonable_encoder(AgentService.list_agent_providers(current_user.id, current_tenant_id))
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
|
||||
@ -42,6 +44,7 @@ class AgentProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_name: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_name: str):
|
||||
return jsonable_encoder(AgentService.get_agent_provider(current_user.id, current_tenant_id, provider_name))
|
||||
|
||||
@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
@ -96,17 +102,15 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the injected workspace and user."""
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, user_id: str, id: str):
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, user_id: str, id: str):
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -4,11 +4,16 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from libs.login import login_required
|
||||
from models import Account, TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -29,8 +34,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
@ -72,8 +78,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@ -25,12 +25,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_allow_transfer_owner,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
@ -136,8 +137,8 @@ class MemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
@ -154,7 +155,8 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
@ -163,7 +165,6 @@ class MemberInviteEmailApi(Resource):
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -223,8 +224,8 @@ class MemberCancelInviteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, member_id: UUID):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
member = db.session.get(Account, str(member_id))
|
||||
@ -256,14 +257,14 @@ class MemberUpdateRoleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def put(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
|
||||
if not TenantAccountRole.is_valid_role(new_role):
|
||||
return {"code": "invalid-role", "message": "Invalid role"}, 400
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not _is_role_enabled(new_role, current_user.current_tenant.id):
|
||||
@ -297,8 +298,8 @@ class DatasetOperatorMemberListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
|
||||
@ -317,13 +318,13 @@ class SendOwnerTransferEmailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferEmailPayload.model_validate(payload)
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# check if the current user is the owner of the workspace
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -355,11 +356,11 @@ class OwnerTransferCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferCheckPayload.model_validate(payload)
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
@ -399,12 +400,12 @@ class OwnerTransfer(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, member_id: UUID):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
# check if the current user is the owner of the workspace
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
if not TenantService.is_owner(current_user, current_user.current_tenant):
|
||||
|
||||
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -95,10 +102,8 @@ class ModelProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
@ -114,9 +119,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
@ -133,8 +137,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialCreate.model_validate(payload)
|
||||
|
||||
@ -157,9 +161,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialUpdate.model_validate(payload)
|
||||
|
||||
@ -184,8 +187,8 @@ class ModelProviderCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialDelete.model_validate(payload)
|
||||
|
||||
@ -205,8 +208,8 @@ class ModelProviderCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialSwitch.model_validate(payload)
|
||||
|
||||
@ -225,8 +228,8 @@ class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserCredentialValidate.model_validate(payload)
|
||||
|
||||
@ -280,11 +283,8 @@ class PreferredProviderTypeUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = ParserPreferredProviderType.model_validate(payload)
|
||||
|
||||
@ -301,10 +301,11 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider: str):
|
||||
if provider != "anthropic":
|
||||
raise ValueError(f"provider name {provider} is invalid")
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
data = BillingService.get_model_provider_payment_link(
|
||||
provider_name=provider,
|
||||
|
||||
@ -8,12 +8,19 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -138,9 +145,8 @@ class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -156,9 +162,8 @@ class DefaultModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args.model_settings
|
||||
@ -189,9 +194,8 @@ class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -202,9 +206,9 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
# To save the model's load balance configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.config_from == "custom-model":
|
||||
@ -249,9 +253,8 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -268,9 +271,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user: Account, provider: str):
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -292,9 +295,13 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
)
|
||||
|
||||
if args.config_from == "predefined-model":
|
||||
# Only the predefined-model branch needs visibility filtering by user.
|
||||
# The account is injected once by the handler and only passed into the
|
||||
# service branch that needs user-scoped credential visibility.
|
||||
available_credentials = model_provider_service.get_provider_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
@ -323,9 +330,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -355,8 +361,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -382,8 +388,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -406,8 +412,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
@ -430,9 +436,8 @@ class ModelProviderModelEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -452,9 +457,8 @@ class ModelProviderModelDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -480,8 +484,8 @@ class ModelProviderModelValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -515,9 +519,9 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
@ -532,8 +536,8 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, model_type: str):
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
||||
@ -69,6 +69,7 @@ class BuiltinToolAddPayload(BaseModel):
|
||||
credentials: dict[str, Any]
|
||||
name: str | None = Field(default=None, max_length=30)
|
||||
type: CredentialType
|
||||
visibility: str | None = None
|
||||
|
||||
|
||||
class BuiltinToolUpdatePayload(BaseModel):
|
||||
@ -277,7 +278,7 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(
|
||||
@ -293,7 +294,7 @@ class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
|
||||
@ -306,7 +307,7 @@ class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = BuiltinToolCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
@ -324,7 +325,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
user_id = user.id
|
||||
@ -338,6 +339,7 @@ class ToolBuiltinProviderAddApi(Resource):
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
api_type=CredentialType.of(payload.type),
|
||||
visibility=payload.visibility,
|
||||
)
|
||||
|
||||
|
||||
@ -348,7 +350,7 @@ class ToolBuiltinProviderUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
user_id = user.id
|
||||
|
||||
@ -370,13 +372,20 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
# Optional list of credential IDs to include even if visibility would hide them
|
||||
# (used when a workflow/agent node still references another member's only_me credential).
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -384,7 +393,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
|
||||
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/icon")
|
||||
class ToolBuiltinProviderIconApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider)
|
||||
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||
@ -784,7 +793,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
@ -822,7 +831,7 @@ class ToolPluginOAuthApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/tool/callback")
|
||||
class ToolOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
@ -859,7 +868,7 @@ class ToolOAuthCallback(Resource):
|
||||
if not credentials:
|
||||
raise Exception("the plugin credentials failed")
|
||||
|
||||
# add credentials to database
|
||||
# add credentials to database — OAuth tokens default to only_me since they're personal
|
||||
BuiltinToolManageService.add_builtin_tool_provider(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@ -867,6 +876,7 @@ class ToolOAuthCallback(Resource):
|
||||
credentials=dict(credentials),
|
||||
expires_at=expires_at,
|
||||
api_type=CredentialType.OAUTH2,
|
||||
visibility="only_me",
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
@ -878,7 +888,7 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
@ -910,7 +920,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -919,7 +929,7 @@ class ToolOAuthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.delete_custom_oauth_client_params(tenant_id=current_tenant_id, provider=provider)
|
||||
@ -931,7 +941,7 @@ class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
|
||||
@ -945,13 +955,18 @@ class ToolBuiltinProviderGetCredentialInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
def get(self, provider: str):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
include_credential_ids = request.args.getlist("include_credential_ids") or [
|
||||
s for s in (request.args.get("include_credential_ids") or "").split(",") if s
|
||||
]
|
||||
|
||||
return jsonable_encoder(
|
||||
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
user=user,
|
||||
include_credential_ids=include_credential_ids or None,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1151,7 +1166,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
@ -1180,7 +1195,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
def get(self, provider_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
|
||||
@ -77,7 +77,7 @@ class TriggerProviderIconApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
@ -103,7 +103,7 @@ class TriggerProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get info for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -119,15 +119,18 @@ class TriggerSubscriptionListApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""List all trigger subscriptions for the current tenant's provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
return jsonable_encoder(
|
||||
TriggerProviderService.list_trigger_provider_subscriptions(
|
||||
tenant_id=user.current_tenant_id, provider_id=TriggerProviderID(provider)
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=TriggerProviderID(provider),
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
@ -146,7 +149,7 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Add a new subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -175,7 +178,7 @@ class TriggerSubscriptionBuilderGetApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get a subscription instance for a trigger provider"""
|
||||
return jsonable_encoder(
|
||||
TriggerSubscriptionBuilderService.get_subscription_builder_by_id(subscription_builder_id)
|
||||
@ -191,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Verify and update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -223,7 +226,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Update a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -257,7 +260,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, subscription_builder_id):
|
||||
def get(self, provider: str, subscription_builder_id: str):
|
||||
"""Get the request logs for a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -280,7 +283,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_builder_id):
|
||||
def post(self, provider: str, subscription_builder_id: str):
|
||||
"""Build a subscription instance for a trigger provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -404,7 +407,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Initiate OAuth authorization flow for a trigger provider"""
|
||||
user = current_user
|
||||
assert isinstance(user, Account)
|
||||
@ -486,7 +489,7 @@ class TriggerOAuthAuthorizeApi(Resource):
|
||||
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
|
||||
class TriggerOAuthCallbackApi(Resource):
|
||||
@setup_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Handle OAuth callback for trigger provider"""
|
||||
context_id = request.cookies.get("context_id")
|
||||
if not context_id:
|
||||
@ -554,7 +557,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
def get(self, provider: str):
|
||||
"""Get OAuth client configuration for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -600,7 +603,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self, provider):
|
||||
def post(self, provider: str):
|
||||
"""Configure custom OAuth client for a provider"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -626,7 +629,7 @@ class TriggerOAuthClientManageApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider):
|
||||
def delete(self, provider: str):
|
||||
"""Remove custom OAuth client configuration"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
@ -654,7 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, subscription_id):
|
||||
def post(self, provider: str, subscription_id: str):
|
||||
"""Verify credentials for an existing subscription (edit mode only)"""
|
||||
user = current_user
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
@ -25,13 +25,15 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
only_edition_enterprise,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -56,6 +58,11 @@ class WorkspaceCustomConfigPayload(BaseModel):
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceCustomConfigResponse(ResponseModel):
|
||||
remove_webapp_brand: bool | None = None
|
||||
replace_webapp_logo: str | None = None
|
||||
|
||||
|
||||
class WorkspaceInfoPayload(BaseModel):
|
||||
name: str
|
||||
|
||||
@ -69,7 +76,7 @@ class TenantInfoResponse(ResponseModel):
|
||||
role: str | None = None
|
||||
in_trial: bool | None = None
|
||||
trial_end_reason: str | None = None
|
||||
custom_config: dict | None = None
|
||||
custom_config: WorkspaceCustomConfigResponse | None = None
|
||||
trial_credits: int | None = None
|
||||
trial_credits_used: int | None = None
|
||||
next_credit_reset_date: int | None = None
|
||||
@ -101,9 +108,13 @@ register_schema_models(
|
||||
SwitchWorkspacePayload,
|
||||
WorkspaceCustomConfigPayload,
|
||||
WorkspaceInfoPayload,
|
||||
TenantInfoResponse,
|
||||
)
|
||||
register_response_schema_models(console_ns, WorkspacePermissionResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
TenantInfoResponse,
|
||||
WorkspaceCustomConfigResponse,
|
||||
WorkspacePermissionResponse,
|
||||
)
|
||||
|
||||
provider_fields = {
|
||||
"provider_name": fields.String,
|
||||
@ -144,8 +155,9 @@ class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
@ -219,11 +231,11 @@ class TenantApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tenant = current_user.current_tenant
|
||||
if not tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -238,13 +250,7 @@ class TenantApi(Resource):
|
||||
else:
|
||||
raise Unauthorized("workspace is archived")
|
||||
|
||||
return (
|
||||
TenantInfoResponse.model_validate(
|
||||
WorkspaceService.get_tenant_info(tenant),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
return dump_response(TenantInfoResponse, WorkspaceService.get_tenant_info(tenant)), 200
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/switch")
|
||||
@ -253,8 +259,8 @@ class SwitchWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = SwitchWorkspacePayload.model_validate(payload)
|
||||
|
||||
@ -278,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
@ -305,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@ -346,8 +352,8 @@ class WorkspaceInfoApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceInfoPayload.model_validate(payload)
|
||||
|
||||
@ -369,13 +375,12 @@ class WorkspacePermissionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_enterprise
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""
|
||||
Get workspace permission settings.
|
||||
Returns permission flags that control workspace features like member invitations and owner transfer.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
|
||||
@ -4,10 +4,12 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
from typing import Any, Concatenate, overload
|
||||
|
||||
from flask import abort, request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError
|
||||
@ -35,9 +37,21 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def account_initialization_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||
@ -216,9 +230,21 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def setup_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def setup_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check setup
|
||||
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
@ -512,9 +538,79 @@ def with_current_tenant_id[T, **P, R](
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
"""Inject the current authenticated Account into the handler as the first argument after self.
|
||||
|
||||
Usage::
|
||||
|
||||
class MyResource(Resource):
|
||||
@login_required
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user_id[T, **P, R](
|
||||
view: Callable[Concatenate[T, str, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
"""Inject the current authenticated user's ID (as a string) into the handler.
|
||||
|
||||
Use this when the handler only needs the user ID and not the full Account object.
|
||||
|
||||
Usage::
|
||||
|
||||
class MyResource(Resource):
|
||||
@login_required
|
||||
@with_current_user_id
|
||||
def get(self, current_user_id: str):
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user.id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def model_validate[T, M: BaseModel, **P, R](
|
||||
model: type[M],
|
||||
) -> Callable[
|
||||
[Callable[Concatenate[T, M, P], R]],
|
||||
Callable[Concatenate[T, P], R],
|
||||
]:
|
||||
"""Validate request data and inject the model instance as the first arg after self.
|
||||
|
||||
Source is determined by HTTP method:
|
||||
GET/DELETE -> request.args
|
||||
POST/PUT/PATCH -> JSON body
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
view: Callable[Concatenate[T, M, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if request.method in ("GET", "DELETE"):
|
||||
raw = request.args.to_dict(flat=True)
|
||||
else:
|
||||
raw = request.get_json(silent=True) or {}
|
||||
|
||||
try:
|
||||
validated = model.model_validate(raw)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
return view(self, validated, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@ -9,7 +9,6 @@ from werkzeug.exceptions import Forbidden
|
||||
import services
|
||||
from core.tools.signature import verify_plugin_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.file_reference import build_file_reference
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
from ..common.errors import (
|
||||
@ -59,8 +58,7 @@ class PluginUploadFileApi(Resource):
|
||||
The file must be accompanied by valid timestamp, nonce, and signature parameters.
|
||||
|
||||
Returns:
|
||||
dict: File metadata including ID, canonical ``reference`` for
|
||||
output-file reconstruction, URLs, and properties
|
||||
dict: File metadata including ID, URLs, and properties
|
||||
int: HTTP status code (201 for success)
|
||||
|
||||
Raises:
|
||||
@ -114,7 +112,6 @@ class PluginUploadFileApi(Resource):
|
||||
# Create a dictionary with all the necessary attributes
|
||||
result = FileResponse(
|
||||
id=tool_file.id,
|
||||
reference=build_file_reference(record_id=tool_file.id),
|
||||
name=tool_file.name,
|
||||
size=tool_file.size,
|
||||
extension=extension,
|
||||
|
||||
@ -12,7 +12,6 @@ from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestFetchAppInfo,
|
||||
RequestRequestDownloadFile,
|
||||
RequestInvokeApp,
|
||||
RequestInvokeEncrypt,
|
||||
RequestInvokeLLM,
|
||||
@ -30,12 +29,10 @@ from core.plugin.entities.request import (
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.signature import get_signed_file_url_for_plugin
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import length_prefixed_response
|
||||
from models import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.file_request_service import FileRequestService
|
||||
|
||||
|
||||
@inner_api_ns.route("/invoke/llm")
|
||||
@ -432,54 +429,6 @@ class PluginUploadFileRequestApi(Resource):
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/download/file/request")
|
||||
class PluginDownloadFileRequestApi(Resource):
|
||||
@setup_required
|
||||
@plugin_inner_api_only
|
||||
@plugin_data(payload_type=RequestRequestDownloadFile)
|
||||
@inner_api_ns.doc("plugin_download_file_request")
|
||||
@inner_api_ns.doc(description="Request signed URL for file download through plugin interface")
|
||||
@inner_api_ns.doc(
|
||||
responses={
|
||||
200: "Signed URL generated successfully",
|
||||
401: "Unauthorized - invalid API key",
|
||||
404: "Service not available",
|
||||
}
|
||||
)
|
||||
def post(self, payload: RequestRequestDownloadFile):
|
||||
"""Resolve signed download metadata for trusted external runtimes.
|
||||
|
||||
Unlike end-user-facing upload/download APIs, this inner endpoint serves
|
||||
trusted callers such as the ``dify-agent`` back proxy. The caller sends
|
||||
flattened ``tenant_id`` / ``user_id`` / ``user_from`` / ``invoke_from``
|
||||
context explicitly in the body, and ``FileRequestService`` rebuilds the
|
||||
corresponding ``FileAccessScope`` before resolving the signed URL.
|
||||
|
||||
The response is control-plane metadata only: filename, mime type, size,
|
||||
and the signed download URL. File bytes still flow through the existing
|
||||
signed file endpoints rather than through this inner API.
|
||||
"""
|
||||
tenant_model = db.session.get(Tenant, payload.tenant_id)
|
||||
if tenant_model is None:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
result = FileRequestService().request_download_url(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=payload.user_id,
|
||||
user_from=payload.user_from,
|
||||
invoke_from=payload.invoke_from,
|
||||
file_mapping=payload.file.model_dump(mode="python", exclude_none=True),
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(
|
||||
data={
|
||||
"filename": result.filename,
|
||||
"mime_type": result.mime_type,
|
||||
"size": result.size,
|
||||
"download_url": result.download_url,
|
||||
}
|
||||
).model_dump()
|
||||
|
||||
|
||||
@inner_api_ns.route("/fetch/app/info")
|
||||
class PluginFetchAppInfoApi(Resource):
|
||||
@get_user_tenant
|
||||
|
||||
@ -45,6 +45,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
.limit(1)
|
||||
)
|
||||
else:
|
||||
# Try id first (preserves the original "explicit end-user
|
||||
# id → that specific user" semantics for callers that pass
|
||||
# a known EndUser.id). Fall back to session_id so daemon-
|
||||
# supplied session UUIDs dedup against the row created on
|
||||
# the first Reverse Invocation call — without this, an
|
||||
# id-only lookup never matched (create writes user_id to
|
||||
# session_id, id is auto-generated) and a fresh EndUser
|
||||
# was created per call, breaking multi-turn chat
|
||||
# continuation (see #36736).
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
@ -53,6 +62,15 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if user_model is None:
|
||||
user_model = session.scalar(
|
||||
select(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
|
||||
@ -7,7 +7,7 @@ from hmac import new as hmac_new
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from core.db.session_factory import session_factory
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
@ -44,6 +44,8 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Inject an EndUser for valid inner API HMAC auth, otherwise pass the request through unchanged."""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
@ -72,9 +74,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
|
||||
if signature_base64 != token:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
kwargs["user"] = db.session.get(EndUser, user_id)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
with session_factory.create_session() as session:
|
||||
kwargs["user"] = session.get(EndUser, user_id)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -37,6 +37,13 @@ from controllers.openapi._models import (
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
MemberListQuery,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
MemberRoleUpdatePayload,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
@ -63,6 +70,9 @@ register_schema_models(
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
MemberInvitePayload,
|
||||
MemberListQuery,
|
||||
MemberRoleUpdatePayload,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
@ -86,6 +96,10 @@ register_response_schema_models(
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
MemberResponse,
|
||||
MemberListResponse,
|
||||
MemberInviteResponse,
|
||||
MemberActionResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from libs.helper import EmailStr, UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
@ -342,3 +342,61 @@ class ApprovalGrantClaimsPayload(BaseModel):
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
# Closed enum for invite/update-role payloads. Owner is intentionally not
|
||||
# assignable through these endpoints — ownership transfer goes through the
|
||||
# console's three-step email-verification flow.
|
||||
MemberAssignableRole = Literal["normal", "admin"]
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
avatar: str | None = None
|
||||
|
||||
|
||||
class MemberListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[MemberResponse]
|
||||
|
||||
|
||||
class MemberListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
email: EmailStr
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberRoleUpdatePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberInviteResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
email: str
|
||||
role: str
|
||||
member_id: str
|
||||
invite_url: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class MemberActionResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
@ -17,18 +17,17 @@ from controllers.openapi._models import (
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
Scope,
|
||||
TokenType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -42,32 +41,18 @@ from services.oauth_device_flow import (
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
|
||||
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
|
||||
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
subject_type="account",
|
||||
subject_email=account.email if account else None,
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
@ -77,19 +62,17 @@ class AccountApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
@ -122,10 +105,9 @@ class AccountSessionsApi(Resource):
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource):
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
@ -16,7 +16,8 @@ import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -124,8 +125,9 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
@ -158,8 +160,9 @@ class AppRunApi(Resource):
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,9 +1,4 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
"""GET /openapi/v1/apps and per-app reads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -28,31 +23,17 @@ from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
@ -99,8 +76,7 @@ class AppReadResource(Resource):
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
return app
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict:
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource):
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
@auth_router.guard_workspace(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
@ -237,7 +210,7 @@ class AppListApi(Resource):
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
|
||||
@ -18,37 +18,27 @@ from controllers.openapi._models import (
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
__all__ = ["auth_router"]
|
||||
|
||||
@ -1,46 +1,75 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_EE,
|
||||
HAS_ALLOWED_ROLES,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED,
|
||||
WORKSPACE_SCOPED,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.prepare import (
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
load_tenant_from_request,
|
||||
load_workspace_role,
|
||||
resolve_external_user,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_api_enabled,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
check_workspace_member,
|
||||
check_workspace_mismatch,
|
||||
check_workspace_role,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(WORKSPACE_MEMBERSHIP_REQUIRED, then=load_tenant_from_request),
|
||||
load_account,
|
||||
When(WORKSPACE_SCOPED, then=load_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(WORKSPACE_SCOPED, then=check_workspace_member),
|
||||
When(PATH_HAS_APP_ID, then=check_workspace_mismatch),
|
||||
When(HAS_ALLOWED_ROLES, then=check_workspace_role),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
external_sso_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(PATH_HAS_APP_ID, then=resolve_external_user),
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
When(PATH_HAS_APP_ID, then=check_app_api_enabled),
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
auth_router = PipelineRouter(
|
||||
{
|
||||
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
|
||||
}
|
||||
)
|
||||
|
||||
60
api/controllers/openapi/auth/conditions.py
Normal file
60
api/controllers/openapi/auth/conditions.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
CondFn = Callable[[RequestContext, AuthData | None], bool]
|
||||
|
||||
|
||||
class Cond:
|
||||
def __init__(self, fn: CondFn) -> None:
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self._fn(ctx, data)
|
||||
|
||||
def __and__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
|
||||
|
||||
def __or__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
|
||||
|
||||
def __invert__(self) -> Cond:
|
||||
return Cond(lambda ctx, data: not self(ctx, data))
|
||||
|
||||
|
||||
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
|
||||
return Cond(lambda ctx, _: fn(ctx))
|
||||
|
||||
|
||||
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
|
||||
return Cond(lambda _, data: data is not None and fn(data))
|
||||
|
||||
|
||||
def config_cond(fn: Callable[[], bool]) -> Cond:
|
||||
return Cond(lambda _, __: fn())
|
||||
|
||||
|
||||
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
|
||||
|
||||
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
|
||||
|
||||
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
|
||||
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
|
||||
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
WORKSPACE_MEMBERSHIP_REQUIRED = request_cond(lambda ctx: ctx.workspace_membership)
|
||||
HAS_ALLOWED_ROLES = request_cond(lambda ctx: ctx.allowed_roles is not None)
|
||||
|
||||
# Caller must belong to the resolved tenant: either an app-scoped path (tenant
|
||||
# from the app) or an explicit workspace-membership path (tenant from request).
|
||||
WORKSPACE_SCOPED = PATH_HAS_APP_ID | WORKSPACE_MEMBERSHIP_REQUIRED
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
@ -1,68 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
75
api/controllers/openapi/auth/data.py
Normal file
75
api/controllers/openapi/auth/data.py
Normal file
@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant, TenantAccountRole
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
class Edition(StrEnum):
|
||||
CE = "ce"
|
||||
EE = "ee"
|
||||
SAAS = "saas"
|
||||
|
||||
|
||||
def current_edition() -> Edition:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
return Edition.SAAS
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return Edition.EE
|
||||
return Edition.CE
|
||||
|
||||
|
||||
class ExternalIdentity(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
email: str
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
workspace_membership: bool = False
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
required_scope: Scope | None = None
|
||||
token_type: TokenType
|
||||
account_id: uuid.UUID | None = None
|
||||
token_hash: str
|
||||
token_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope]
|
||||
tenants: dict[str, bool] = Field(default_factory=dict)
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
tenant_role: TenantAccountRole | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]:
|
||||
if self.app is None or self.caller is None or self.caller_kind is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app context missing")
|
||||
return self.app, self.caller, self.caller_kind
|
||||
19
api/controllers/openapi/auth/flow.py
Normal file
19
api/controllers/openapi/auth/flow.py
Normal file
@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
|
||||
self.condition = condition
|
||||
self._step = then
|
||||
|
||||
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self.condition(ctx, data)
|
||||
|
||||
def __call__(self, arg: Any) -> None:
|
||||
self._step(arg)
|
||||
@ -1,51 +1,262 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
"""Auth pipeline — entry point for all openapi auth.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
`PipelineRouter.guard()` is the only attachment point for endpoints.
|
||||
`AuthPipeline` is a pure step-runner with no routing concerns.
|
||||
`PipelineRoute` binds a pipeline to optional edition requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
TokenType,
|
||||
extract_bearer,
|
||||
get_authenticator,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models.account import TenantAccountRole
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
class AuthPipeline:
|
||||
"""Pure step-runner — no routing, no guard.
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
Both `prepare` and `auth` steps receive the same `AuthData` instance.
|
||||
`prepare` steps populate it; `auth` steps validate it.
|
||||
"""
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
def __init__(self, prepare: list, auth: list) -> None:
|
||||
self._prepare = prepare
|
||||
self._auth = auth
|
||||
|
||||
def _run(
|
||||
self,
|
||||
identity: AuthContext,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
token_type=identity.token_type,
|
||||
account_id=identity.account_id,
|
||||
token_hash=identity.token_hash,
|
||||
token_id=identity.token_id,
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
allowed_roles=allowed_roles,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
if identity.subject_email
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
for step in self._prepare:
|
||||
if _should_run(step, req_ctx, data=None):
|
||||
step(data)
|
||||
|
||||
for step in self._auth:
|
||||
if _should_run(step, req_ctx, data=data):
|
||||
step(data)
|
||||
|
||||
reset_token = set_auth_ctx(identity)
|
||||
if data.caller:
|
||||
_mount_flask_login(data.caller)
|
||||
|
||||
try:
|
||||
kwargs["auth_data"] = data
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
reset_auth_ctx(reset_token)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRoute:
|
||||
pipeline: AuthPipeline
|
||||
required_edition: frozenset[Edition] | None = None
|
||||
|
||||
|
||||
class PipelineRouter:
|
||||
"""Entry point for openapi auth.
|
||||
|
||||
`guard()` is the decorator that endpoints attach to. It applies
|
||||
global gates (edition, token type) then dispatches to the matching
|
||||
`PipelineRoute` for the token type.
|
||||
"""
|
||||
|
||||
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
|
||||
self._routes = routes
|
||||
|
||||
def guard(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def guard_workspace(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Callable:
|
||||
return self._make_decorator(
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=True,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
def _make_decorator(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._execute(
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
workspace_membership: bool = False,
|
||||
allowed_roles: frozenset[TenantAccountRole] | None = None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
raise NotFound()
|
||||
|
||||
license_checked = False
|
||||
if edition is not None and Edition.EE in edition:
|
||||
_check_license()
|
||||
license_checked = True
|
||||
|
||||
token = extract_bearer(request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
identity = get_authenticator().authenticate(token)
|
||||
|
||||
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
|
||||
emit_wrong_surface(
|
||||
subject_type=_subject_type_str(identity),
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(identity, "client_id", None),
|
||||
token_id=str(identity.token_id) if identity.token_id else None,
|
||||
)
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
route = self._routes.get(identity.token_type)
|
||||
if route is None:
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
if route.required_edition is not None:
|
||||
if current_edition() not in route.required_edition:
|
||||
raise Forbidden("external_sso_requires_ee")
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(
|
||||
identity,
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
workspace_membership=workspace_membership,
|
||||
allowed_roles=allowed_roles,
|
||||
)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
if isinstance(step, When):
|
||||
return step.applies(req_ctx, data)
|
||||
return True
|
||||
|
||||
|
||||
def _subject_type_str(identity: Any) -> str | None:
|
||||
subject = getattr(identity, "subject_type", None)
|
||||
if subject is None:
|
||||
return None
|
||||
return subject.value if hasattr(subject, "value") else str(subject)
|
||||
|
||||
|
||||
def _check_license() -> None:
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
|
||||
raise Forbidden("license_invalid")
|
||||
|
||||
|
||||
def _mount_flask_login(user: Any) -> None:
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]
|
||||
|
||||
103
api/controllers/openapi/auth/prepare.py
Normal file
103
api/controllers/openapi/auth/prepare.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
if data.app is not None:
|
||||
return
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_tenant_from_request(data: AuthData) -> None:
|
||||
if data.tenant is not None:
|
||||
return
|
||||
workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise NotFound("workspace not found")
|
||||
try:
|
||||
uuid.UUID(workspace_id)
|
||||
except ValueError:
|
||||
raise NotFound("workspace not found")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise NotFound("workspace not found")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
if data.caller is not None:
|
||||
return
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
if data.tenant:
|
||||
account.current_tenant = data.tenant
|
||||
data.caller = account
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def load_workspace_role(data: AuthData) -> None:
|
||||
if data.tenant_role is not None:
|
||||
return
|
||||
if data.tenant is None or data.account_id is None:
|
||||
return
|
||||
if data.caller is not None and getattr(data.caller, "status", None) != "active":
|
||||
return
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(data.account_id), str(data.tenant.id))
|
||||
if role is None:
|
||||
return
|
||||
data.tenant_role = role
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=str(data.tenant.id),
|
||||
app_id=str(data.app.id),
|
||||
user_id=data.external_identity.email,
|
||||
)
|
||||
data.caller = end_user
|
||||
data.caller_kind = "end_user"
|
||||
|
||||
|
||||
def load_app_access_mode(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
try:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id))
|
||||
if settings is None:
|
||||
data.app_access_mode = None
|
||||
return
|
||||
data.app_access_mode = WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
data.app_access_mode = None
|
||||
@ -1,170 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
@ -1,168 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
105
api/controllers/openapi/auth/verify.py
Normal file
105
api/controllers/openapi/auth/verify.py
Normal file
@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def check_scope(data: AuthData) -> None:
|
||||
if data.required_scope is None:
|
||||
return
|
||||
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_workspace_member(data: AuthData) -> None:
|
||||
"""Assert the caller belongs to the resolved tenant.
|
||||
|
||||
`load_workspace_role` stashes the membership role (None when the caller is
|
||||
not a member or is inactive). A missing membership surfaces as 404, not
|
||||
403, so workspace IDs don't leak across tenants.
|
||||
"""
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
|
||||
def check_workspace_mismatch(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
request_workspace_id = data.path_params.get("workspace_id") or request.args.get("workspace_id")
|
||||
if request_workspace_id and request_workspace_id != str(data.tenant.id):
|
||||
raise UnprocessableEntity("workspace_id does not match app's workspace")
|
||||
|
||||
|
||||
def check_workspace_role(data: AuthData) -> None:
|
||||
if data.allowed_roles is None:
|
||||
return
|
||||
if data.tenant_role is None:
|
||||
raise NotFound("workspace not found")
|
||||
if data.tenant_role not in data.allowed_roles:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
|
||||
def check_app_api_enabled(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
if not data.app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
|
||||
TokenType.OAUTH_ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def check_acl(data: AuthData) -> None:
|
||||
if data.app is None or data.app_access_mode is None:
|
||||
raise Forbidden("app or access mode not loaded")
|
||||
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
|
||||
if data.app_access_mode not in allowed_modes:
|
||||
raise Forbidden("subject_not_allowed_for_access_mode")
|
||||
|
||||
|
||||
def check_private_app_permission(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise Forbidden("app not loaded")
|
||||
user_id = _resolve_user_id(data)
|
||||
if user_id is None:
|
||||
raise Forbidden("cannot resolve user for private app check")
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id):
|
||||
raise Forbidden("user_not_allowed_for_private_app")
|
||||
|
||||
|
||||
def _resolve_user_id(data: AuthData) -> str | None:
|
||||
if data.token_type == TokenType.OAUTH_ACCOUNT:
|
||||
return str(data.account_id) if data.account_id is not None else None
|
||||
if data.external_identity is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
|
||||
return str(account.id) if account is not None else None
|
||||
@ -17,11 +17,11 @@ from controllers.common.errors import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -39,8 +39,9 @@ class AppFileUploadApi(Resource):
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user