Compare commits

..

156 Commits

Author SHA1 Message Date
44ae313dae Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-22 10:22:47 +08:00
67a17c9924 [autofix.ci] apply automated fixes 2026-05-21 07:55:24 +00:00
6696e49df3 Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-21 15:50:39 +08:00
bd94db64b7 test(web): fix test of member selector in delivery method configure 2026-05-21 09:49:36 +08:00
f0fef15d2a chore(api): rebase migration 2026-05-21 09:21:54 +08:00
e49607f545 Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-21 09:07:30 +08:00
52db8b7e72 merge main 2026-05-20 15:37:57 +08:00
6cbc97590d chore(web): use tailwind v4 style 2026-05-20 15:31:26 +08:00
49fa6aa975 merge main 2026-05-20 15:24:15 +08:00
418049090f fix(web): fix type check 2026-05-19 16:27:32 +08:00
1150f492aa test(web): tests coverage of human input 2026-05-19 16:10:56 +08:00
3f8f7a9b87 chore(web): fix knip 2026-05-19 15:32:21 +08:00
205ea1bd98 fix(web): lint error in form status card 2026-05-19 15:19:00 +08:00
e26042cd90 [autofix.ci] apply automated fixes 2026-05-19 06:40:55 +00:00
8626e50ee6 test(web): fix test of human input node content 2026-05-19 14:36:40 +08:00
beafa0308e Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-19 14:34:03 +08:00
8d200afb32 fix(web): fix leaked conditional rendering
Replace `&&` with ternary operator.

ref: https://www.eslint-react.xyz/docs/rules/no-leaked-conditional-rendering
2026-05-19 11:08:06 +08:00
41a88a6580 [autofix.ci] apply automated fixes 2026-05-19 02:57:08 +00:00
45c20a24f2 Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-19 09:14:48 +08:00
0b4348ee1c [autofix.ci] apply automated fixes 2026-05-18 07:45:54 +00:00
d2c5539937 Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-18 15:41:40 +08:00
ca9e738d53 Merge remote-tracking branch 'upstream/main' into codex/upgrade-graphon-0-4-0 2026-05-18 15:31:05 +08:00
423edab679 fix(web): input field overflow 2026-05-16 13:30:48 +08:00
e906ad1289 fix(web): add input fields 2026-05-15 18:20:34 +08:00
1c1970e47b docs(api): add a design docs for human input form page file upload 2026-05-15 14:47:12 +08:00
450e19b46c fix(api): ensure select options are properly substitued 2026-05-15 10:19:29 +08:00
f69779fc0c chore(api): linearize migration 2026-05-14 16:59:13 +08:00
fdc20e460b fix(api): fix tests 2026-05-14 16:58:45 +08:00
68ea88c82d fix(api): adapt graphon changes and fix tests 2026-05-14 15:45:51 +08:00
f045359e67 chore(api): adapt new graphon api 2026-05-14 15:26:22 +08:00
4ff95853c2 fix(api): resolve_variable_select_input_options should raise type error for invalid variable 2026-05-14 15:24:36 +08:00
c96d108198 test(api): ensure select values are replaced in events 2026-05-14 14:47:23 +08:00
c067498f73 chore(api): upgrade graphon to v0.4.0 2026-05-14 14:29:27 +08:00
da90c934c0 test(api): add a test to ensure extra_contents exists 2026-05-14 14:21:47 +08:00
226ce25596 chore(api): fix typing errors 2026-05-14 14:21:47 +08:00
e897e275ef fix(api): change values for select input to runtime values 2026-05-14 14:21:47 +08:00
5f1b47c21b fix(api): fix missing extra_contents in HITL forms 2026-05-14 14:21:47 +08:00
186fd60712 Merge branch 'main' into tp 2026-05-13 14:49:47 +08:00
221a5020ad fix(web): z-index of add input field 2026-05-13 14:30:13 +08:00
51b3e63472 fix(web): add input field crash 2026-05-13 14:14:53 +08:00
070f8bcf14 Merge branch 'main' into tp 2026-05-13 11:15:58 +08:00
698bb902ca Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement 2026-05-13 10:03:25 +08:00
9a65969744 test(api): inject file_reference_factory to HumanInputNode in test 2026-05-13 10:01:13 +08:00
7facbfcb71 test(api): introduce a test for HumanInputNode initialization 2026-05-13 10:00:47 +08:00
fce6d017d6 test(api): renaming import according to graphon changes 2026-05-13 10:00:23 +08:00
f9729e8764 fix(api): fix HumanInputNode initialization error 2026-05-13 09:59:41 +08:00
b3c007461d chore(api): upgrade grpahon dependency 2026-05-13 09:59:22 +08:00
3169ae7f53 Merge branch 'main' into tp 2026-05-12 18:07:52 +08:00
ac05dc9fa3 fix(web): readonly state of form content 2026-05-12 18:05:33 +08:00
3f372352d8 fix(web): input fields variable name can not be duplicated 2026-05-12 17:27:14 +08:00
c86fd4cba0 fix(web): tests for component-ui & single-run-form 2026-05-12 16:28:50 +08:00
26c14fd58a Merge branch 'main' into tp 2026-05-12 15:55:20 +08:00
8bb70e78b7 fix(web): dynamic select display in form content 2026-05-12 15:05:15 +08:00
94ad8674d6 fix(web): form content action button disable state 2026-05-12 15:05:15 +08:00
998201f6e3 fix(web): dynamic select display 2026-05-12 15:05:15 +08:00
17759c0b80 Merge remote-tracking branch 'upstream/main' into feat/hitl-form-enhancement 2026-05-12 15:04:10 +08:00
3d445e1d95 Merge branch 'main' into tp 2026-05-09 12:03:24 +08:00
1e6700b679 chore(api): replace graphon with development branch 2026-05-09 10:39:34 +08:00
4b5b00ce63 Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement 2026-05-09 10:38:24 +08:00
07fea50216 fix(web): remote file upload 2026-05-09 10:36:16 +08:00
f1833fdb08 Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement 2026-05-09 10:04:43 +08:00
e9070daaaa Merge branch 'main' into tp 2026-05-09 09:49:36 +08:00
0e389d223f fix(web): email input 2026-05-09 09:49:17 +08:00
132f80dd9e Merge branch 'main' into tp 2026-05-09 08:12:29 +08:00
e099ba8679 test(api): add tests for delivery and file inputs 2026-05-09 02:13:37 +08:00
c4b2985361 test(api): fix broken HITL tests 2026-05-09 02:12:46 +08:00
343982bd46 chore(api): improve documentation for HumanInputFormSubmitPayload 2026-05-09 02:10:54 +08:00
bba7001de2 fix(api): fix file uploading for delivery test form 2026-05-09 01:44:34 +08:00
62efb66a2f Merge branch 'main' into tp 2026-05-08 18:02:14 +08:00
e3e4f77de1 Merge branch 'main' into tp 2026-05-08 17:36:17 +08:00
a75208b432 Merge branch 'main' into tp 2026-05-08 16:53:27 +08:00
eadaaa1aa0 fix(web): z-index of shortcut popup 2026-05-08 16:53:12 +08:00
74c4d720d4 test(web): fix tests of new file uploader 2026-05-08 16:02:39 +08:00
c8d6ad117e Merge branch 'main' into tp 2026-05-08 15:25:47 +08:00
7133754a31 feat(api): bind UploadFile to workflow initiator in unauthenticated form submission
The basic assumption of Workflow execution for now is that only one user
(`Account` or `EndUser`) participate the workflow execution. For
unauthenticated form submission this assumption does not hold. Binding
the uploaded file to worfklow initiator aligns with current implementation.

For auditing the actual uploading recipient, a dedicated table
`HumanInputFormUploadFile` is introduced to record the uploading
behavior.
2026-05-08 14:32:51 +08:00
f1adc60822 Merge branch 'main' into tp 2026-05-08 13:35:26 +08:00
58af8aa7fe refactor(api): use TypedDict to model file mapping 2026-05-08 11:58:16 +08:00
ed98925f11 Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement 2026-05-08 11:44:37 +08:00
a9de4bd96b fix(web): form in email test sender 2026-05-08 11:40:56 +08:00
d65cc21e85 fix(web): support dynamic selector in human input step run & email configure 2026-05-08 11:06:32 +08:00
23d39beeed fix(web): add email configuration check in human input node 2026-05-08 10:41:55 +08:00
3f35f3594b fix(web): form submmit in human input form page 2026-05-08 10:32:53 +08:00
04ed797ac9 merge main 2026-05-08 10:22:35 +08:00
4c386e3ea7 fix(api): fix missing import and name error 2026-05-07 16:47:41 +08:00
3f6559dd60 Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement 2026-05-07 16:44:29 +08:00
c0bedd9118 Merge branch 'main' into tp 2026-05-07 14:18:50 +08:00
b73a4d2700 test(api): add file and file list form type in resumption test 2026-05-07 10:58:51 +08:00
02e42fd66f chore(api): update mock path in tests 2026-05-07 10:53:38 +08:00
a0f8db5516 test(api): add tests about file input file uploading api 2026-05-07 10:52:54 +08:00
37681bce8c test(api): add tests about submission response 2026-05-07 10:51:51 +08:00
23e59c6778 test(api): update import names in tests
Upstream project graphon renamed some classes. Modify the tests to keep
import names consistent with upstream.
2026-05-07 10:47:44 +08:00
51e181c588 feat(api): introduce file upload apis for human input page 2026-05-07 10:43:00 +08:00
d6f607f6e7 feat(api): expose sumitted_data to frontend 2026-05-07 10:40:52 +08:00
cd91757623 Merge branch 'main' into tp 2026-05-07 10:16:57 +08:00
651dfe5dca feat(web): new upload api for human input form page 2026-05-06 17:54:41 +08:00
21a9c8d59c Merge branch 'main' into tp 2026-05-06 15:39:10 +08:00
4fce9ee8e5 Merge branch 'main' into tp 2026-04-27 21:41:30 +08:00
90ab734a05 chore(web): replace form_data with submitted_data 2026-04-27 21:33:06 +08:00
ad43a46c37 Merge branch 'main' into tp 2026-04-27 15:57:23 +08:00
e5b5c1fa3b Merge branch 'main' into tp 2026-04-27 13:49:30 +08:00
6e4fa39db5 Merge branch 'main' into tp 2026-04-27 11:16:57 +08:00
e994009476 Revert "fix(web): file_list type"
This reverts commit 46747993d4.
2026-04-27 11:13:05 +08:00
4da8afaed5 fix(web): use number_limits for file_list type 2026-04-27 10:38:51 +08:00
bdecea34a3 Merge branch 'main' into tp 2026-04-27 10:20:08 +08:00
9ad5d89b07 feat(web): Use number_limits for file uploading limit
Align with the current form definition in start node.

Co-Authored-By: GPT 5.4 <codex@openai.com>
2026-04-27 07:12:21 +08:00
74f17d0ec8 refactor(api): rename form definitions fields 2026-04-26 10:51:49 +08:00
7a12d46a45 refactor(web): human input form page 2026-04-24 16:21:32 +08:00
cec437b35b fix(web): fix human input form filled UI 2026-04-24 15:23:40 +08:00
1c5d877372 fix(web): human input form content submittion 2026-04-24 11:21:58 +08:00
a8e663863d fix(web): human input step run preview restriction 2026-04-24 11:06:58 +08:00
60f577fd11 fix(web): step run of file uploader 2026-04-24 10:56:19 +08:00
46747993d4 fix(web): file_list type 2026-04-24 10:27:55 +08:00
b8481f6d6f fix(web): fix rebase error 2026-04-23 19:34:45 +08:00
8ce356c98f fix(web): fix node handle 2026-04-23 19:27:14 +08:00
d72794bc67 fix(web): form content preview 2026-04-23 18:56:41 +08:00
7316a1be2b fix(web): style issue 2026-04-23 18:56:41 +08:00
1a0776ce9b fix(web): type select change 2026-04-23 18:56:41 +08:00
7c348e994c fix(web): form content style issues 2026-04-23 18:56:41 +08:00
1a3c1a9b32 fix(web): node handle fix 2026-04-23 18:56:38 +08:00
6263bb1b74 Add missing zh-Hans human input labels 2026-04-23 18:55:34 +08:00
d494e42166 Test human input form submissions across entry points 2026-04-23 18:55:34 +08:00
5d20502494 Test status-agnostic human input history parsing 2026-04-23 18:55:34 +08:00
75e5429534 Test human input submitted markdown fallback 2026-04-23 18:55:34 +08:00
ccb43eb856 Cover file-based human input content items 2026-04-23 18:55:34 +08:00
4f63b2b162 Cover file type transitions in human input tests 2026-04-23 18:55:34 +08:00
1b45cdb08a Expand human input utils initialization tests 2026-04-23 18:55:34 +08:00
cfb501c1f6 Localize human input field configuration labels 2026-04-23 18:55:34 +08:00
703228ffed Cover structured human input panel rendering 2026-04-23 18:55:34 +08:00
cbe2f66f1b Track human input selector variable references 2026-04-23 18:55:34 +08:00
b879748ba0 Parse human input extra contents by payload 2026-04-23 18:55:34 +08:00
29cad80e62 Prefer structured submitted human input data 2026-04-23 18:55:34 +08:00
2ad35e56bc Add typed submitted human input renderer 2026-04-23 18:55:34 +08:00
add9260e58 Align human input submission payload types 2026-04-23 18:55:34 +08:00
94b8f8f170 Preserve typed human input values in share form 2026-04-23 18:55:34 +08:00
67c832e60e Allow typed single-run human input submissions 2026-04-23 18:55:34 +08:00
e9fb3bd751 Cover non-string human input chat submissions 2026-04-23 18:55:34 +08:00
80ede7cdb5 Initialize human input values by field type 2026-04-23 18:55:34 +08:00
5309b56225 Use shared renderer for human input content 2026-04-23 18:55:34 +08:00
8d3ddee7d3 Extract shared human input field renderer 2026-04-23 18:55:34 +08:00
c2fd595a82 Fix human input typing regressions 2026-04-23 18:55:34 +08:00
85d05f5113 Preview human input field types in markdown 2026-04-23 18:55:34 +08:00
ccf61c0372 Summarize human input field configurations 2026-04-23 18:55:34 +08:00
49195fffdd Configure multi-file human input fields 2026-04-23 18:55:34 +08:00
4002d58171 Configure single-file human input fields 2026-04-23 18:55:34 +08:00
cae5315c18 Guard select fields against default leakage 2026-04-23 18:55:34 +08:00
eb2eefdbb5 Constrain select option variables to string arrays 2026-04-23 18:55:34 +08:00
82d410325b Support variable-backed select options 2026-04-23 18:55:34 +08:00
00f0f6d040 Add constant select options editor 2026-04-23 18:55:34 +08:00
022d73d0ed Lock paragraph prepopulate behavior 2026-04-23 18:55:34 +08:00
71803d7c76 Add human input field type selector 2026-04-23 18:55:34 +08:00
93945d603e Show typed human input outputs in panel 2026-04-23 18:55:34 +08:00
acd1641d16 Infer human input output variable types 2026-04-23 18:55:34 +08:00
63711bf1dc Refine human input extra content types 2026-04-23 18:55:34 +08:00
37f79ee5d1 Extend human input runtime data types 2026-04-23 18:55:34 +08:00
5faa4f9520 Define detailed human input field types 2026-04-23 18:55:34 +08:00
6c4f293719 Refactor human input form item types 2026-04-23 18:55:34 +08:00
1248 changed files with 21954 additions and 60083 deletions

View File

@ -1,15 +0,0 @@
**/node_modules
**/.pnpm-store
**/dist
**/.next
**/.turbo
**/.cache
**/__pycache__
**/*.pyc
**/.mypy_cache
**/.ruff_cache
.git
.github
*.md
!web/README.md
!api/README.md

4
.gitattributes vendored
View File

@ -5,7 +5,3 @@
# them.
*.sh text eol=lf
# Codegen output must stay byte-identical across platforms so
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
*.generated.ts text eol=lf

4
.github/CODEOWNERS vendored
View File

@ -18,10 +18,6 @@
# Docs
/docs/ @crazywoola
# CLI
/cli/ @langgenius/maintainers
/.github/workflows/cli-tests.yml @langgenius/maintainers
# Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost

111
.github/dependabot.yml vendored
View File

@ -110,114 +110,3 @@ updates:
github-actions-dependencies:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "github-actions"
directory: "/"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 5
schedule:
interval: "weekly"
groups:
github-actions-dependencies:
patterns:
- "*"

View File

@ -1,88 +0,0 @@
name: CLI Release
on:
workflow_dispatch:
push:
tags:
- 'difyctl-v*'
concurrency:
group: cli-release-${{ github.ref }}
cancel-in-progress: true
jobs:
release:
name: build standalone binaries (all targets)
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
permissions:
contents: write
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
fetch-depth: 0
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Setup Bun
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
with:
bun-version: latest
- name: Read cli/package.json
id: manifest
run: |
version=$(node -p "require('./package.json').version")
channel=$(node -p "require('./package.json').difyctl.channel")
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
{
echo "version=$version"
echo "channel=$channel"
echo "minDify=$minDify"
echo "maxDify=$maxDify"
} >> "$GITHUB_OUTPUT"
- name: Validate manifest
run: scripts/release-validate-manifest.sh
- name: Install cross-arch native prebuilds
# Re-installs node_modules with every @napi-rs/keyring platform variant
# so `bun build --compile` can embed the right .node into each target.
working-directory: ./
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
- name: Compile standalone binaries (all targets)
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
run: |
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
pnpm build:bin
- name: Generate sha256 checksum file
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
run: scripts/release-write-checksums.sh
- name: Publish GitHub Release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
name: difyctl ${{ steps.manifest.outputs.version }}
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
generate_release_notes: true
fail_on_unmatched_files: true
files: |
cli/dist/bin/difyctl-v*

View File

@ -1,60 +0,0 @@
name: CLI Smoke (live dify)
on:
workflow_dispatch:
inputs:
dify_version:
description: "Dify image tag to test against (e.g. 1.7.0)"
type: string
required: true
cli_ref:
description: "Git ref to build the cli from (default: current branch)"
type: string
required: false
permissions:
contents: read
jobs:
smoke:
runs-on: ubuntu-latest
timeout-minutes: 30
defaults:
run:
shell: bash
steps:
- name: Checkout cli ref
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Bring up dify
env:
DIFY_VERSION: ${{ inputs.dify_version }}
run: |
cd docker
cp .env.example .env
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
docker compose up -d api worker web db redis
for i in $(seq 1 60); do
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
echo "dify api ready after ${i}s"
break
fi
sleep 1
done
- name: Run smoke against live dify
working-directory: ./cli
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
- name: Dump dify logs on failure
if: failure()
run: |
cd docker
docker compose logs api worker web --tail=200

View File

@ -1,46 +0,0 @@
name: CLI Tests
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: cli-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: CLI Tests
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: CI pipeline (typecheck, lint, coverage, build)
run: pnpm ci
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: cli/coverage
flags: cli
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -42,7 +42,6 @@ jobs:
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
cli-changed: ${{ steps.changes.outputs.cli }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
@ -63,18 +62,6 @@ jobs:
- 'docker/generate_docker_compose'
- 'docker/ssrf_proxy/**'
- 'docker/volumes/sandbox/conf/**'
cli:
- 'cli/**'
- 'packages/tsconfig/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- 'eslint.config.mjs'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/cli-tests.yml'
- '.github/workflows/cli-docker-build.yml'
- '.github/actions/setup-web/**'
web:
- 'web/**'
- 'packages/**'
@ -197,66 +184,6 @@ jobs:
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
cli-tests-run:
name: Run CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
uses: ./.github/workflows/cli-tests.yml
secrets: inherit
cli-tests-skip:
name: Skip CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped CLI tests
run: echo "No CLI-related changes detected; skipping CLI tests."
cli-tests:
name: CLI Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- cli-tests-run
- cli-tests-skip
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize CLI Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
RUN_RESULT: ${{ needs.cli-tests-run.result }}
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "CLI tests ran successfully."
exit 0
fi
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "CLI tests were skipped because no CLI-related files changed."
exit 0
fi
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run:
name: Run Web Tests
needs:

9
.gitignore vendored
View File

@ -115,12 +115,6 @@ venv/
ENV/
env.bak/
venv.bak/
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
!/cli/src/env/
!/cli/src/commands/env/
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
!/cli/scripts/lib/
.conda/
# Spyder project settings
@ -253,9 +247,8 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
*.local.toml
# Code Agent Folder
.qoder/*
.context/
.context/*
.eslintcache

View File

@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration

View File

@ -17,7 +17,7 @@ FROM base AS packages
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# basic environment
g++ \
git g++ \
# for building gmpy2
libmpfr-dev libmpc-dev

View File

@ -159,7 +159,6 @@ def initialize_extensions(app: DifyApp):
ext_logstore,
ext_mail,
ext_migrate,
ext_oauth_bearer,
ext_orjson,
ext_otel,
ext_proxy_fix,
@ -204,7 +203,6 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
ext_oauth_bearer,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -30,7 +30,7 @@ from clients.agent_backend.factory import create_agent_backend_run_client
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
from clients.agent_backend.request_builder import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_PLUGIN_CONTEXT_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendModelConfig,
@ -42,7 +42,7 @@ from clients.agent_backend.request_builder import (
__all__ = [
"AGENT_SOUL_PROMPT_LAYER_ID",
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendError",

View File

@ -4,9 +4,7 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
protocol has a single owner. API-only context such as Agent Soul vs workflow job
prompt is preserved in layer names and metadata until the dedicated product
schemas land in later phases. Dify-owned execution identifiers are emitted as an
explicit ``dify.execution_context`` layer so the run request stays fully
composition-driven.
schemas land in later phases.
"""
from __future__ import annotations
@ -17,19 +15,18 @@ from agenton.compositor import CompositorSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLayerConfig,
DifyPluginLLMLayerConfig,
)
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
RunComposition,
RunLayerSpec,
@ -40,15 +37,17 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
tenant_id: str
plugin_id: str
model_provider: str
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
@ -56,14 +55,10 @@ class AgentBackendModelConfig(BaseModel):
class AgentBackendOutputConfig(BaseModel):
"""API-side structured output declaration for the conventional output layer.
The structured-output tool name is fixed to ``final_output`` inside
``dify_agent.layers.output`` so callers only control the JSON Schema plus
optional description/strictness metadata.
"""
"""API-side structured output declaration for the conventional output layer."""
json_schema: dict[str, JsonValue]
name: str = "final_result"
description: str | None = None
strict: bool | None = None
@ -74,7 +69,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
execution_context: ExecutionContext
workflow_node_job_prompt: str
user_prompt: str
agent_soul_prompt: str | None = None
@ -126,18 +121,21 @@ class AgentBackendRunRequestBuilder:
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
type=DIFY_PLUGIN_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
config=DifyPluginLayerConfig(
tenant_id=run_input.model.tenant_id,
plugin_id=run_input.model.plugin_id,
user_id=run_input.model.user_id,
),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
@ -155,6 +153,7 @@ class AgentBackendRunRequestBuilder:
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
name=run_input.output.name,
description=run_input.output.description,
strict=run_input.output.strict,
),
@ -163,6 +162,7 @@ class AgentBackendRunRequestBuilder:
return CreateRunRequest(
composition=RunComposition(layers=layers),
execution_context=run_input.execution_context,
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,

View File

@ -11,7 +11,6 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)

View File

@ -1,5 +1,3 @@
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
@ -25,7 +23,7 @@ class DeploymentConfig(BaseSettings):
default=False,
)
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
EDITION: str = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
)

View File

@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
default=60 * 60,
)
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching tenant plugin model providers in Redis",
default=60 * 60 * 24,
)
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
description="Maximum allowed size (bytes) for plugin-generated files",
default=50 * 1024 * 1024,
@ -525,44 +520,6 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
OPENAPI_ENABLED: bool = Field(
description=(
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
"programmatic clients. Set to true to activate; disabled by default."
),
validation_alias=AliasChoices("OPENAPI_ENABLED"),
default=False,
)
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
description=(
"Comma-separated allowlist for /openapi/v1/* CORS. "
"Default empty = same-origin only. Browser-cookie routes within "
"the group reject cross-origin OPTIONS regardless of this list."
),
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
default="",
)
@computed_field
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
description=(
"Comma-separated client_id values accepted at "
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
"without code changes. Unknown client_id returns 400 unsupported_client."
),
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
default="difyctl",
)
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
)
@ -938,17 +895,6 @@ class AuthConfig(BaseSettings):
default=86400,
)
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
class ModerationConfig(BaseSettings):
"""
@ -1235,14 +1181,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, override
from typing import Any
from pydantic import Field
from pydantic.fields import FieldInfo
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

View File

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Mapping
from typing import Any, override
from typing import Any
from pydantic.fields import FieldInfo
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
field_value = self.remote_configs.get(field_name)
if field_value is None:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, override, runtime_checkable
from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
self._config = config or {}
self._extensions: dict[str, Any] = {}
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
self._extensions[name] = extension
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield

View File

@ -6,7 +6,7 @@ import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final, override
from typing import Any, final
from flask import Flask, current_app, g
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
"""
self._flask_app = flask_app
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():

View File

@ -1,10 +1,40 @@
import json
from pydantic import BaseModel, JsonValue
from pydantic import BaseModel, Field, JsonValue
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
"decision": "approve",
"attachment": {
"transfer_method": "local_file",
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
"type": "document",
},
"attachments": [
{
"transfer_method": "local_file",
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
"type": "document",
},
{
"transfer_method": "remote_url",
"url": "https://example.com/report.pdf",
"type": "document",
},
],
}
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
inputs: dict[str, JsonValue] = Field(
description=(
"Submitted human input values keyed by output variable name. "
"Use a string for paragraph or select input values, a file mapping for file inputs, "
"and a list of file mappings for file-list inputs. Local file mappings use "
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
"`transfer_method=remote_url` with `url` or `remote_url`."
),
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
)
action: str

View File

@ -68,7 +68,6 @@ from .app import (
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_node_output_inspector,
workflow_run,
workflow_statistic,
workflow_trigger,
@ -219,7 +218,6 @@ __all__ = [
"workflow_app_log",
"workflow_comment",
"workflow_draft_variable",
"workflow_node_output_inspector",
"workflow_run",
"workflow_statistic",
"workflow_trigger",

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -82,7 +80,7 @@ class AgentRosterDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
def get(self, agent_id):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
@ -91,7 +89,7 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, agent_id: UUID):
def patch(self, agent_id):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return _agent_roster_service().update_roster_agent(
@ -102,7 +100,7 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, agent_id: UUID):
def delete(self, agent_id):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
return "", 204
@ -113,7 +111,7 @@ class AgentRosterVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
def get(self, agent_id):
_, tenant_id = current_account_with_tenant()
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
@ -123,7 +121,7 @@ class AgentRosterVersionDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID, version_id: UUID):
def get(self, agent_id, version_id):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,

View File

@ -1,5 +1,4 @@
from datetime import datetime
from uuid import UUID
import flask_restx
from flask_restx import Resource
@ -9,25 +8,18 @@ from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_response_schema_models
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import dump_response, to_timestamp
from libs.login import login_required
from models import Account
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from .wraps import account_initialization_required, edit_permission_required, setup_required
class ApiKeyItem(ResponseModel):
@ -47,7 +39,7 @@ class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
def _get_resource(resource_id, tenant_id, resource_model):
@ -71,11 +63,10 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
keys = db.session.scalars(
@ -83,14 +74,13 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
@edit_permission_required
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
_, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
current_key_count: int = (
db.session.scalar(
@ -117,7 +107,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return api_token
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
class BaseApiKeyResource(Resource):
@ -127,20 +117,9 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None
resource_id_field: str | None = None
def delete(
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
) -> tuple[str, int]:
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
return "", 204
def _delete_api_key(
self,
resource_id: str,
api_key_id: str,
current_tenant_id: str,
current_user: Account,
) -> None:
def delete(self, resource_id: str, api_key_id: str):
assert self.resource_id_field is not None, "resource_id_field must be set"
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
@ -167,6 +146,8 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return "", 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource):
@ -174,21 +155,18 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
return super().get(resource_id)
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
return super().post(resource_id)
resource_type = ApiTokenType.APP
resource_model = App
@ -202,14 +180,9 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.APP
resource_model = App
@ -222,21 +195,18 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
return super().get(resource_id)
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@with_current_tenant_id
@edit_permission_required
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
return super().post(resource_id)
resource_type = ApiTokenType.DATASET
resource_model = Dataset
@ -250,14 +220,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
) -> tuple[str, int]:
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
return "", 204
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.DATASET
resource_model = Dataset

View File

@ -159,15 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id: UUID, annotation_setting_id: UUID):
annotation_setting_id_str = str(annotation_setting_id)
def post(self, app_id: UUID, annotation_setting_id):
annotation_setting_id = str(annotation_setting_id)
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(
str(app_id), annotation_setting_id_str, setting_args
)
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
return result, 200
@ -183,9 +181,9 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id: UUID, job_id: UUID, action: str):
job_id_str = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
def get(self, app_id: UUID, job_id, action):
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
@ -193,10 +191,10 @@ class AnnotationReplyActionStatusApi(Resource):
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@console_ns.route("/apps/<uuid:app_id>/annotations")

View File

@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model, with_session
from controllers.console.app.wraps import get_app_model
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
@ -26,6 +26,7 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -851,11 +852,11 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_session
@get_app_model
def get(self, session: Session, app_model: App):
def get(self, app_model):
"""Get app trace"""
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
return app_trace_config

View File

@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
def post(self, import_id):
# Check user role first
current_user, _ = current_account_with_tenant()

View File

@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id: str):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id: str):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
import sqlalchemy as sa
from flask import abort, request
@ -134,7 +133,7 @@ class CompletionConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.group_by(Conversation.id)
.distinct()
)
elif args.annotation_status == "not_annotated":
query = (
@ -165,10 +164,10 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
conversation_id_str = str(conversation_id)
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -182,12 +181,12 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
def delete(self, app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
conversation_id = str(conversation_id)
try:
ConversationService.delete(app_model, conversation_id_str, current_user)
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -272,7 +271,7 @@ class ChatConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.group_by(Conversation.id)
.distinct()
)
case "not_annotated":
query = (
@ -318,10 +317,10 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
conversation_id_str = str(conversation_id)
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -335,12 +334,12 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
def delete(self, app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
conversation_id = str(conversation_id)
try:
ConversationService.delete(app_model, conversation_id_str, current_user)
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,7 +1,6 @@
import json
from datetime import datetime
from typing import Any
from uuid import UUID
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
@ -163,7 +162,7 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, server_id: UUID):
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = db.session.scalar(
select(AppMCPServer)

View File

@ -1,7 +1,6 @@
import logging
from datetime import datetime
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -337,13 +336,13 @@ class MessageSuggestedQuestionApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id: UUID):
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id_str = str(message_id)
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, message_id=message_id_str, user=current_user, invoke_from=InvokeFrom.DEBUGGER
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
@ -417,11 +416,11 @@ class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model, message_id: UUID):
message_id_str = str(message_id)
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id_str, Message.app_id == app_model.id).limit(1)
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:

View File

@ -2,7 +2,6 @@ import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -346,15 +345,14 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: UUID):
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
return variable
@ -365,7 +363,7 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: UUID):
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
# Local File:
@ -392,11 +390,10 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
new_name = args_model.name
@ -437,15 +434,14 @@ class VariableApi(Resource):
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: UUID):
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
@ -461,7 +457,7 @@ class VariableResetApi(Resource):
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: UUID):
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -472,11 +468,10 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)

View File

@ -1,415 +0,0 @@
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
with a producer-organized view of each node's declared outputs and their
per-run status. This module exposes two parallel sets of three read-only
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
schedule / plugin triggers). Both sets share the same service code, the same
response shapes, and the same error codes; the URL is the *only* difference,
so the frontend can pick the right prefix based on which run-detail page the
user is on.
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
``published_run_inspector_not_implemented`` 404 code is therefore no longer
produced.
URLs follow the design doc and reuse the existing
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
:mod:`controllers.console.app.workflow_draft_variable`. The
``published`` prefix mirrors it shape-for-shape.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from uuid import UUID
from flask import Response
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.exception import BaseHTTPException
from libs.login import login_required
from models import App, AppMode
from services.workflow import inspector_events
from services.workflow.node_output_inspector_service import (
NodeOutputInspectorError,
NodeOutputInspectorService,
)
logger = logging.getLogger(__name__)
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
# intervening proxies (nginx, ingress) don't reap the idle connection.
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
_HEARTBEAT_EVERY_TICKS = 15
# Hard ceiling on a single stream — if we never see a terminal workflow
# event (engine crashed, redis dropped the message), force-close after this
# many ticks (= seconds).
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
def _service() -> NodeOutputInspectorService:
"""One-line factory so tests can monkeypatch a stub if needed."""
return NodeOutputInspectorService()
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
"""Resource-body shared by draft + published snapshot endpoints.
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
behaviour lives here, where unit tests can hit it without spinning up
Flask request context.
"""
try:
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return snapshot.model_dump(mode="json")
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
"""Resource-body shared by draft + published node-detail endpoints."""
try:
view = _service().node_detail(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return view.model_dump(mode="json")
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
"""Resource-body shared by draft + published output-preview endpoints."""
try:
preview = _service().output_preview(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
output_name=output_name,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return preview.model_dump(mode="json")
class _InspectorNotFound(BaseHTTPException):
"""404 that preserves the inspector's specific error code.
Without this the response body collapses to a generic ``not_found`` code
and clients lose the ability to distinguish, e.g.,
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
"""
code = 404
def __init__(self, error: NodeOutputInspectorError) -> None:
self.error_code = error.code
super().__init__(description=str(error))
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
class WorkflowDraftRunNodeOutputsApi(Resource):
"""Whole-run snapshot organized by producer node."""
@console_ns.doc("get_workflow_draft_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowDraftRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status."""
@console_ns.doc("get_workflow_draft_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output (with signed URL for file refs)."""
@console_ns.doc("get_workflow_draft_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
# ──────────────────────────────────────────────────────────────────────────────
# SSE event stream — shared generator used by draft + published variants
# ──────────────────────────────────────────────────────────────────────────────
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
``data`` is JSON-serialized when given as a dict; raw strings are
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
"""
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
"""Yield SSE-framed strings for one workflow run.
The stream begins with a full ``snapshot`` event so the client has a
starting state without needing a separate REST GET. Then for every
``node_changed`` message from the pub/sub channel we re-read that node
from DB and push a fresh ``node_changed`` event. When the workflow run
reaches a terminal state we push one final ``workflow_run_completed``
event and close the stream.
Failures inside the loop are caught and surfaced as ``error`` events so
the frontend can show a banner rather than seeing the connection drop
silently. The Inspector never raises across the SSE boundary.
"""
service = _service()
run_id_str = str(run_id)
# Initial snapshot — also flushes a 404 back at the client right away
# if the run is gone (raised before yielding any bytes, so Flask turns it
# into the normal HTTP 404 path).
try:
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
event_id = 0
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
# If the run already finished by the time the client connected, emit
# the terminal envelope synchronously and close — no point subscribing.
# The enum value for partial success is the hyphenated ``partial-succeeded``
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
event_id,
)
return
# Live subscription
ticks_since_heartbeat = 0
total_ticks = 0
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
total_ticks += 1
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
logger.warning(
"Inspector SSE: forcing close after %ds without terminal event for run %s",
_STREAM_HARD_TIMEOUT_TICKS,
run_id_str,
)
return
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
# ``node_changed`` message with both fields ``None`` on every redis
# timeout. Real ``workflow_completed`` messages keep their kind even
# when status couldn't be resolved (publisher race), so checking kind
# first makes the heartbeat branch safe.
if message.kind == "node_changed" and message.node_id is None and message.status is None:
ticks_since_heartbeat += 1
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
yield ":keepalive\n\n"
ticks_since_heartbeat = 0
continue
ticks_since_heartbeat = 0
if message.kind == "workflow_completed":
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
event_id,
)
return
# node_changed: recompute the node slice from DB
if not message.node_id:
continue
try:
node_view = service.node_detail(
app_model=app_model,
workflow_run_id=run_id_str,
node_id=message.node_id,
)
except NodeOutputInspectorError:
# Node may not appear in the graph yet (race with persistence); skip.
continue
except Exception:
logger.warning(
"Inspector SSE: node_detail failed for run %s node %s",
run_id_str,
message.node_id,
exc_info=True,
)
event_id += 1
yield _sse_envelope(
"error",
{"node_id": message.node_id, "message": "failed to refresh node detail"},
event_id,
)
continue
event_id += 1
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
class WorkflowDraftRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a draft run."""
@console_ns.doc("stream_workflow_draft_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
# ──────────────────────────────────────────────────────────────────────────────
# Published-run endpoints — symmetric to the draft trio above
# ──────────────────────────────────────────────────────────────────────────────
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
class WorkflowPublishedRunNodeOutputsApi(Resource):
"""Whole-run snapshot for a *published* workflow run.
Same response shape as the ``/draft/`` variant — frontend can multiplex
based on which page (Composer test-run vs. Run History) is mounted.
"""
@console_ns.doc("get_workflow_published_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status (published run)."""
@console_ns.doc("get_workflow_published_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
"/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output of a published run."""
@console_ns.doc("get_workflow_published_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output of a published run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a published run."""
@console_ns.doc("stream_workflow_published_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -1,6 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -189,7 +188,7 @@ class WorkflowRunExportApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id: str):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)
@ -368,14 +367,14 @@ class WorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id):
"""
Get workflow run detail
"""
run_id_str = str(run_id)
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id_str)
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
if workflow_run is None:
raise NotFoundError("Workflow run not found")
@ -397,17 +396,17 @@ class WorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id):
"""
Get workflow run node execution list
"""
run_id_str = str(run_id)
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
user = cast("Account | EndUser", current_user)
node_executions = workflow_run_service.get_workflow_run_node_executions(
app_model=app_model,
run_id=run_id_str,
run_id=run_id,
user=user,
)

View File

@ -1,38 +1,16 @@
"""Controller decorators for console app resources.
`with_session` opens one SQLAlchemy session for a request handler and injects it
as the first argument after `self`. Handlers use a transaction by default so
migrated write paths keep commit/rollback handling; pure read handlers may opt
out with `write=False`. App-loading decorators prefer that injected session when
present, while still supporting existing handlers that have not been migrated
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
"""
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, cast, overload
from typing import overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.error import AppNotFoundError
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
def _load_app_model(session: Session, app_id: str) -> App | None:
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
_, current_tenant_id = current_account_with_tenant()
app_model = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
@ -45,63 +23,6 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R],
*,
write: bool = True,
) -> Callable[Concatenate[T, P], R]: ...
@overload
def with_session[T, **P, R](
view: None = None,
*,
write: bool = True,
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R] | None = None,
*,
write: bool = True,
) -> (
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
):
"""Inject a request-scoped session, using a transaction only for write handlers."""
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if write:
with session_factory.get_session_maker().begin() as session:
return view(self, session, *args, **kwargs)
with session_factory.create_session() as session:
return view(self, session, *args, **kwargs)
return wrapper
if view is None:
return decorator
return decorator(view)
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
if len(args) < 2:
return None
candidate = args[1]
if isinstance(candidate, Session):
return candidate
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
return cast(Session, candidate)
return None
@overload
def get_app_model[**P, R](
view: Callable[P, R],
@ -123,13 +44,6 @@ def get_app_model[**P, R](
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Inject the App model for handlers that receive an `app_id` path parameter.
New handlers may compose `@with_session` above this decorator so the app row
is loaded through the same request-scoped session used by the controller.
Existing handlers continue to work through `db.session` until migrated.
"""
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@ -141,11 +55,7 @@ def get_app_model[**P, R](
del kwargs["app_id"]
session = _get_injected_session(args)
if session is None:
app_model = _load_app_model_from_scoped_session(app_id)
else:
app_model = _load_app_model(session, app_id)
app_model = _load_app_model(app_id)
if not app_model:
raise AppNotFoundError()

View File

@ -1,16 +1,14 @@
from uuid import UUID
from flask_restx import Resource
from pydantic import BaseModel, Field
from controllers.common.schema import register_response_schema_models, register_schema_models
from fields.base import ResponseModel
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from .. import console_ns
from ..auth.error import ApiKeyAuthFailedError
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
class ApiKeyAuthBindingPayload(BaseModel):
@ -42,8 +40,8 @@ class ApiKeyAuthDataSource(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
if data_source_api_key_bindings:
return {
@ -69,9 +67,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
@ -89,9 +87,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.response(204, "Binding deleted successfully")
@with_current_tenant_id
def delete(self, current_tenant_id: str, binding_id: UUID):
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return "", 204

View File

@ -1,5 +1,4 @@
import logging
from uuid import UUID
import httpx
from flask import current_app, redirect, request
@ -159,15 +158,16 @@ class OAuthDataSourceSync(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str, binding_id: UUID):
binding_id_str = str(binding_id)
def get(self, provider, binding_id):
provider = str(provider)
binding_id = str(binding_id)
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id_str)
oauth_provider.sync_data_source(binding_id)
except httpx.HTTPStatusError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text

View File

@ -8,9 +8,9 @@ from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
from controllers.console.wraps import account_initialization_required, setup_required
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
@ -133,10 +133,12 @@ class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
user_account_id = current_user.id
def post(self, oauth_provider_app: OAuthProviderApp):
current_user, _ = current_account_with_tenant()
account = current_user
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{

View File

@ -1,7 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource, fields, marshal_with
@ -48,6 +47,7 @@ class NotionEstimatePayload(BaseModel):
class DataSourceNotionListQuery(BaseModel):
dataset_id: str | None = Field(default=None, description="Dataset ID")
credential_id: str = Field(..., description="Credential ID", min_length=1)
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
class DataSourceNotionPreviewQuery(BaseModel):
@ -204,6 +204,9 @@ class DataSourceNotionListApi(Resource):
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters = query.datasource_parameters or {}
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -251,7 +254,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
datasource_parameters=datasource_parameters,
provider_type=datasource_runtime.datasource_provider_type(),
)
)
@ -290,7 +293,7 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, page_id: UUID, page_type: str):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
@ -303,11 +306,11 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
page_id_str = str(page_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id="",
notion_obj_id=page_id_str,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_tenant_id,
@ -364,7 +367,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -382,7 +385,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -1,17 +1,15 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
@ -32,10 +30,26 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import build_icon_url, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
@ -47,6 +61,58 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
@ -142,165 +208,9 @@ class ConsoleDatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetListItemResponse(DatasetDetailResponse):
partial_member_list: list[str]
class DatasetListResponse(ResponseModel):
data: list[DatasetListItemResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
class DatasetQueryFileInfoResponse(ResponseModel):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetQueryContentResponse(ResponseModel):
content_type: str
content: str
file_info: DatasetQueryFileInfoResponse | None = None
class DatasetQueryDetailResponse(ResponseModel):
id: str
queries: list[DatasetQueryContentResponse]
source: str
source_app_id: str | None
created_by_role: str
created_by: str
created_at: int
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DatasetQueryListResponse(ResponseModel):
data: list[DatasetQueryDetailResponse]
has_more: bool
limit: int
total: int
page: int
class RelatedAppResponse(ResponseModel):
id: str
name: str
description: str
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None
icon: str | None
icon_background: str | None
icon_url: str | None = None
@model_validator(mode="after")
def _set_icon_url(self) -> "RelatedAppResponse":
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
return self
class RelatedAppListResponse(ResponseModel):
data: list[RelatedAppResponse]
total: int
class DocumentStatusResponse(ResponseModel):
id: str
indexing_status: str
processing_started_at: int | None
parsing_completed_at: int | None
cleaning_completed_at: int | None
splitting_completed_at: int | None
completed_at: int | None
paused_at: int | None
error: str | None
stopped_at: int | None
completed_segments: int | None = None
total_segments: int | None = None
@field_validator(
"processing_started_at",
"parsing_completed_at",
"cleaning_completed_at",
"splitting_completed_at",
"completed_at",
"paused_at",
"stopped_at",
mode="before",
)
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentStatusListResponse(ResponseModel):
data: list[DocumentStatusResponse]
class ErrorDocsResponse(DocumentStatusListResponse):
total: int
class IndexingEstimatePreviewItemResponse(ResponseModel):
content: str
child_chunks: list[str] | None = None
summary: str | None = None
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
question: str
answer: str
class IndexingEstimateResponse(ResponseModel):
total_segments: int
preview: list[IndexingEstimatePreviewItemResponse]
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
class RetrievalSettingResponse(ResponseModel):
retrieval_method: list[str]
class PartialMemberListResponse(ResponseModel):
data: list[str]
class AutoDisableLogsResponse(ResponseModel):
document_ids: list[str]
count: int
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
register_response_schema_models(
console_ns,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetQueryListResponse,
IndexingEstimateResponse,
RelatedAppListResponse,
DocumentStatusListResponse,
ErrorDocsResponse,
RetrievalSettingResponse,
PartialMemberListResponse,
AutoDisableLogsResponse,
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -383,8 +293,17 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
class DatasetListApi(Resource):
@console_ns.doc("get_datasets")
@console_ns.doc(description="Get list of datasets")
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
@console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
"ids": "Filter by dataset IDs (list)",
"keyword": "Search keyword",
"tag_ids": "Filter by tag IDs (list)",
"include_all": "Include all datasets (default: false)",
}
)
@console_ns.response(200, "Datasets retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@ -423,7 +342,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
partial_members_map: dict[str, list[str]] = {}
if dataset_ids:
@ -460,12 +379,12 @@ class DatasetListApi(Resource):
"total": total,
"page": query.page,
}
return dump_response(DatasetListResponse, response), 200
return response, 200
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@ -494,7 +413,7 @@ class DatasetListApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return dump_response(DatasetDetailResponse, dataset), 201
return marshal(dataset, dataset_detail_fields), 201
@console_ns.route("/datasets/<uuid:dataset_id>")
@ -502,17 +421,13 @@ class DatasetApi(Resource):
@console_ns.doc("get_dataset")
@console_ns.doc(description="Get dataset details")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Dataset retrieved successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -522,7 +437,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = dump_response(DatasetDetailResponse, dataset)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -555,18 +470,14 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(
200,
"Dataset updated successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID):
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -595,7 +506,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = dump_response(DatasetDetailResponse, dataset)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
@ -614,7 +525,7 @@ class DatasetApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Dataset deleted successfully")
def delete(self, dataset_id: UUID):
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
@ -644,7 +555,7 @@ class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
@ -656,15 +567,11 @@ class DatasetQueryApi(Resource):
@console_ns.doc("get_dataset_queries")
@console_ns.doc(description="Get dataset query history")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Query history retrieved successfully",
console_ns.models[DatasetQueryListResponse.__name__],
)
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -682,24 +589,20 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
"data": dataset_queries,
"data": marshal(dataset_queries, dataset_query_detail_model),
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
}
return dump_response(DatasetQueryListResponse, response), 200
return response, 200
@console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource):
@console_ns.doc("estimate_dataset_indexing")
@console_ns.doc(description="Estimate dataset indexing cost")
@console_ns.response(
200,
"Indexing estimate calculated successfully",
console_ns.models[IndexingEstimateResponse.__name__],
)
@console_ns.response(200, "Indexing estimate calculated successfully")
@setup_required
@login_required
@account_initialization_required
@ -796,15 +699,12 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.doc("get_dataset_related_apps")
@console_ns.doc(description="Get applications related to dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Related apps retrieved successfully",
console_ns.models[RelatedAppListResponse.__name__],
)
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
@marshal_with(related_app_list_model)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -824,7 +724,7 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
return {"data": related_apps, "total": len(related_apps)}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
@ -832,19 +732,15 @@ class DatasetIndexingStatusApi(Resource):
@console_ns.doc("get_dataset_indexing_status")
@console_ns.doc(description="Get dataset indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Indexing status retrieved successfully",
console_ns.models[DocumentStatusListResponse.__name__],
)
@console_ns.response(200, "Indexing status retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@ -882,8 +778,9 @@ class DatasetIndexingStatusApi(Resource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data, 200
@console_ns.route("/datasets/api-keys")
@ -952,15 +849,15 @@ class DatasetApiDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, api_key_id: UUID):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id_str = str(api_key_id)
api_key_id = str(api_key_id)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id_str,
ApiToken.id == api_key_id,
)
.limit(1)
)
@ -985,7 +882,7 @@ class DatasetEnableApiApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, dataset_id: UUID, status: str):
def post(self, dataset_id, status):
dataset_id_str = str(dataset_id)
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
@ -1010,18 +907,13 @@ class DatasetApiBaseUrlApi(Resource):
class DatasetRetrievalSettingApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting")
@console_ns.doc(description="Get dataset retrieval settings")
@console_ns.response(
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
)
@console_ns.response(200, "Retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
)
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -1029,19 +921,12 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting_mock")
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
@console_ns.doc(params={"vector_type": "Vector store type"})
@console_ns.response(
200,
"Mock retrieval settings retrieved successfully",
console_ns.models[RetrievalSettingResponse.__name__],
)
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type: str):
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
)
def get(self, vector_type):
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@ -1049,19 +934,19 @@ class DatasetErrorDocs(Resource):
@console_ns.doc("get_dataset_error_docs")
@console_ns.doc(description="Get dataset error documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
@console_ns.response(200, "Error documents retrieved successfully")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
@ -1069,17 +954,13 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.doc("get_dataset_permission_users")
@console_ns.doc(description="Get dataset permission user list")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Permission users retrieved successfully",
console_ns.models[PartialMemberListResponse.__name__],
)
@console_ns.response(200, "Permission users retrieved successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -1092,7 +973,9 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
return {
"data": partial_members_list,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
@ -1100,18 +983,14 @@ class DatasetAutoDisableLogApi(Resource):
@console_ns.doc("get_dataset_auto_disable_logs")
@console_ns.doc(description="Get dataset auto disable logs")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Auto disable logs retrieved successfully",
console_ns.models[AutoDisableLogsResponse.__name__],
)
@console_ns.response(200, "Auto disable logs retrieved successfully")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
@ -316,9 +315,9 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
@ -343,7 +342,7 @@ class DatasetDocumentListApi(Resource):
)
except (ArgumentTypeError, ValueError, Exception):
fetch = False
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -352,7 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
@ -373,7 +372,7 @@ class DatasetDocumentListApi(Resource):
sa.select(
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
)
.where(DocumentSegment.dataset_id == dataset_id_str)
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
@ -445,11 +444,11 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -473,7 +472,7 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -491,9 +490,9 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Documents deleted successfully")
def delete(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
@ -583,11 +582,11 @@ class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
@ -625,7 +624,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
data_process_rule_dict,
document.doc_form,
"English",
dataset_id_str,
dataset_id,
)
return estimate_response.model_dump(), 200
except LLMBadRequestError:
@ -648,10 +647,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
@ -725,7 +725,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict,
document.doc_form,
"English",
dataset_id_str,
dataset_id,
)
return response.model_dump(), 200
except LLMBadRequestError:
@ -745,9 +745,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
def get(self, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
completed_segments = (
@ -799,16 +800,16 @@ class DocumentIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -817,7 +818,7 @@ class DocumentIndexingStatusApi(DocumentResource):
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -860,10 +861,10 @@ class DocumentApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
@ -872,7 +873,7 @@ class DocumentApi(DocumentResource):
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
response = {
"id": document.id,
@ -906,7 +907,7 @@ class DocumentApi(DocumentResource):
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
response = {
"id": document.id,
@ -949,16 +950,16 @@ class DocumentApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
@ -979,7 +980,7 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@ -996,16 +997,16 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id: str):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
dataset_id=dataset_id_str,
dataset_id=dataset_id,
document_ids=document_ids,
tenant_id=current_tenant_id,
current_user=current_user,
@ -1043,11 +1044,11 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1091,11 +1092,11 @@ class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id: UUID, document_id: UUID):
def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
@ -1140,10 +1141,10 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -1178,16 +1179,16 @@ class DocumentPauseApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document paused successfully")
def patch(self, dataset_id: UUID, document_id: UUID):
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
@ -1213,14 +1214,14 @@ class DocumentRecoverApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document resumed successfully")
def patch(self, dataset_id: UUID, document_id: UUID):
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
@ -1246,11 +1247,11 @@ class DocumentRetryApi(DocumentResource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
@console_ns.response(204, "Documents retry started successfully")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
"""retry document."""
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
@ -1276,7 +1277,7 @@ class DocumentRetryApi(DocumentResource):
logger.exception("Failed to retry document, document id: %s", document_id)
continue
# retry document
DocumentService.retry_document(dataset_id_str, retry_documents)
DocumentService.retry_document(dataset_id, retry_documents)
return "", 204
@ -1288,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
@ -1300,7 +1301,7 @@ class DocumentRenameApi(DocumentResource):
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
try:
document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name)
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
@ -1313,15 +1314,15 @@ class WebsiteDocumentSyncApi(DocumentResource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_tenant_id:
@ -1332,7 +1333,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# sync document
DocumentService.sync_website_document(dataset_id_str, document)
DocumentService.sync_website_document(dataset_id, document)
return {"result": "success"}, 200
@ -1342,19 +1343,19 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = db.session.scalar(
select(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document_id_str)
.where(DocumentPipelineExecutionLog.document_id == document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.limit(1)
)
@ -1391,7 +1392,7 @@ class DocumentGenerateSummaryApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
"""
Generate summary index for specified documents.
@ -1400,10 +1401,10 @@ class DocumentGenerateSummaryApi(Resource):
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -1437,7 +1438,7 @@ class DocumentGenerateSummaryApi(Resource):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list)
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
@ -1451,7 +1452,7 @@ class DocumentGenerateSummaryApi(Resource):
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id_str,
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
@ -1464,11 +1465,11 @@ class DocumentGenerateSummaryApi(Resource):
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id_str, document.id)
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id_str,
dataset_id,
)
return {"result": "success"}, 200
@ -1484,7 +1485,7 @@ class DocumentSummaryStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
@ -1498,11 +1499,11 @@ class DocumentSummaryStatusApi(DocumentResource):
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -1516,8 +1517,8 @@ class DocumentSummaryStatusApi(DocumentResource):
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id_str,
dataset_id=dataset_id_str,
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200

View File

@ -1,6 +1,4 @@
import uuid
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource, marshal
@ -115,12 +113,12 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -129,7 +127,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
@ -150,7 +148,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = (
select(DocumentSegment)
.where(
DocumentSegment.document_id == document_id_str,
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
@ -203,9 +201,7 @@ class DatasetDocumentSegmentListApi(Resource):
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
@ -230,19 +226,19 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
@ -266,15 +262,15 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting
@ -325,17 +321,17 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if not current_user.is_dataset_editor:
@ -365,7 +361,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -376,19 +372,19 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -408,10 +404,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -432,33 +428,33 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -487,17 +483,17 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
@ -521,8 +517,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
batch_create_segment_to_index_task.delay(
str(job_id),
upload_file_id,
dataset_id_str,
document_id_str,
dataset_id,
document_id,
current_tenant_id,
current_user.id,
)
@ -534,7 +530,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, job_id=None, dataset_id: UUID | None = None, document_id: UUID | None = None):
def get(self, job_id=None, dataset_id=None, document_id=None):
if job_id is None:
raise NotFound("The job does not exist.")
job_id = str(job_id)
@ -555,24 +551,24 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -610,26 +606,26 @@ class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -646,9 +642,7 @@ class ChildChunkAddApi(Resource):
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
@ -662,26 +656,26 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -711,39 +705,39 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
ChildChunk.document_id == document_id,
)
.limit(1)
)
@ -768,39 +762,39 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
ChildChunk.document_id == document_id,
)
.limit(1)
)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
@ -10,12 +8,7 @@ from controllers.common.fields import UsageCountResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.dataset_fields import (
dataset_detail_fields,
dataset_retrieval_model_fields,
@ -131,9 +124,9 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.response(200, "External API templates retrieved successfully")
@setup_required
@login_required
@with_current_tenant_id
@account_initialization_required
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
@ -182,11 +175,11 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id_str, current_tenant_id
external_knowledge_api_id, current_tenant_id
)
if external_knowledge_api is None:
raise NotFound("API template not found.")
@ -197,9 +190,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id: UUID):
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
@ -207,7 +200,7 @@ class ExternalApiTemplateApi(Resource):
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id_str,
external_knowledge_api_id=external_knowledge_api_id,
args=payload.model_dump(),
)
@ -217,14 +210,14 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(204, "External knowledge API deleted successfully")
def delete(self, external_knowledge_api_id: UUID):
def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id_str)
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return "", 204
@ -237,12 +230,12 @@ class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id_str, current_tenant_id
external_knowledge_api_id, current_tenant_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
@ -293,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from flask_restx import Resource
from pydantic import Field, field_validator
@ -119,7 +118,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask_restx import Resource
from werkzeug.exceptions import NotFound
@ -43,7 +42,7 @@ class DatasetMetadataCreateApi(Resource):
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
@ -63,7 +62,7 @@ class DatasetMetadataCreateApi(Resource):
@console_ns.response(
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -80,7 +79,7 @@ class DatasetMetadataApi(Resource):
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id: UUID, metadata_id: UUID):
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -100,7 +99,7 @@ class DatasetMetadataApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
def delete(self, dataset_id: UUID, metadata_id: UUID):
def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -137,7 +136,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -165,7 +164,7 @@ class DocumentMetadataEditApi(Resource):
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
import services
@ -54,13 +54,12 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
session.commit()
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_tenant_id,

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
@ -169,22 +168,21 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: UUID):
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: UUID):
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
# Local File:
@ -212,12 +210,11 @@ class RagPipelineVariableApi(Resource):
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
@ -253,16 +250,15 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: UUID):
def delete(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -271,7 +267,7 @@ class RagPipelineVariableApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: UUID):
def put(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -282,12 +278,11 @@ class RagPipelineVariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()

View File

@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@ -67,12 +67,10 @@ class RagPipelineImportApi(Resource):
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
# (which return FAILED status instead of re-raising) do not leave the
# transaction in a closed state that a .begin() context manager cannot
# handle. See app_import.py for the canonical pattern.
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
result = import_service.import_rag_pipeline(
account=account,
@ -82,10 +80,6 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
status = result.status
@ -105,17 +99,15 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
def post(self, import_id: str):
def post(self, import_id):
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -132,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@ -150,7 +142,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@ -1,7 +1,6 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource
@ -876,14 +875,14 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: UUID):
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
"""
run_id_str = str(run_id)
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id_str)
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id)
if workflow_run is None:
raise NotFound("Workflow run not found")
@ -901,17 +900,17 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: UUID):
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
"""
run_id_str = str(run_id)
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
user = cast("Account | EndUser", current_user)
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline,
run_id=run_id_str,
run_id=run_id,
user=user,
)
@ -961,15 +960,15 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
def post(self, dataset_id: str):
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
rag_pipeline_transform_service = RagPipelineTransformService()
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
result = rag_pipeline_transform_service.transform_dataset(dataset_id)
return result

View File

@ -20,7 +20,6 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from models.model import InstalledApp
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -41,10 +40,8 @@ register_schema_model(console_ns, TextToAudioPayload)
endpoint="installed_app_audio",
)
class ChatAudioApi(InstalledAppResource):
def post(self, installed_app: InstalledApp):
def post(self, installed_app):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
file = request.files["file"]
@ -84,10 +81,8 @@ class ChatAudioApi(InstalledAppResource):
)
class ChatTextApi(InstalledAppResource):
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
def post(self, installed_app: InstalledApp):
def post(self, installed_app):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
try:
payload = TextToAudioPayload.model_validate(console_ns.payload or {})

View File

@ -31,7 +31,7 @@ from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import AppMode, InstalledApp
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
@ -83,10 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse)
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
def post(self, installed_app: InstalledApp):
def post(self, installed_app):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
@ -135,10 +133,8 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app: InstalledApp, task_id: str):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
@ -161,10 +157,8 @@ class CompletionStopApi(InstalledAppResource):
)
class ChatApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
def post(self, installed_app: InstalledApp):
def post(self, installed_app):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -215,10 +209,8 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app: InstalledApp, task_id: str):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
@ -8,7 +7,6 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
@ -21,7 +19,7 @@ from fields.conversation_fields import (
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
from models.model import AppMode, InstalledApp
from models.model import AppMode
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService
@ -45,10 +43,8 @@ register_response_schema_models(console_ns, ResultResponse)
)
class ConversationListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app: InstalledApp):
def get(self, installed_app):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -95,10 +91,8 @@ class ConversationListApi(InstalledAppResource):
)
class ConversationApi(InstalledAppResource):
@console_ns.response(204, "Conversation deleted successfully")
def delete(self, installed_app: InstalledApp, c_id: UUID):
def delete(self, installed_app, c_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -120,10 +114,8 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app: InstalledApp, c_id: UUID):
def post(self, installed_app, c_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -153,10 +145,8 @@ class ConversationRenameApi(InstalledAppResource):
)
class ConversationPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app: InstalledApp, c_id: UUID):
def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -179,10 +169,8 @@ class ConversationPinApi(InstalledAppResource):
)
class ConversationUnPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app: InstalledApp, c_id: UUID):
def patch(self, installed_app, c_id):
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource):
"""
@console_ns.response(204, "App uninstalled successfully")
def delete(self, installed_app: InstalledApp):
def delete(self, installed_app):
_, current_tenant_id = current_account_with_tenant()
if installed_app.app_owner_tenant_id == current_tenant_id:
raise BadRequest("You can't uninstall an app owned by the current tenant")
@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource):
return "", 204
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
def patch(self, installed_app: InstalledApp):
def patch(self, installed_app):
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
commit_args = False

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from pydantic import BaseModel, TypeAdapter
@ -10,7 +9,6 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
AppUnavailableError,
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
@ -22,16 +20,15 @@ from controllers.console.explore.error import (
NotCompletionAppError,
)
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from models import Account
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode, InstalledApp
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
@ -61,11 +58,9 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
)
class MessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp):
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -100,20 +95,18 @@ class MessageListApi(InstalledAppResource):
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
message_id_str = str(message_id)
message_id = str(message_id)
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id_str,
message_id=message_id,
user=current_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
@ -130,15 +123,13 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
message_id_str = str(message_id)
message_id = str(message_id)
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
@ -148,7 +139,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id_str,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming,
)
@ -178,20 +169,18 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id_str = str(message_id)
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
@ -7,14 +5,11 @@ from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from models import Account
from models.model import InstalledApp
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -25,11 +20,9 @@ register_response_schema_models(console_ns, ResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
@with_current_user
def get(self, current_user: Account, installed_app: InstalledApp):
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -51,11 +44,9 @@ class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -74,17 +65,15 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
@console_ns.response(204, "Saved message deleted successfully")
@with_current_user
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
message_id_str = str(message_id)
message_id = str(message_id)
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id_str)
SavedMessageService.delete(app_model, current_user, message_id)
return "", 204

View File

@ -13,7 +13,6 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -26,7 +25,7 @@ from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from models import Account
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -42,11 +41,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
def post(self, installed_app: InstalledApp):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()

View File

@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -9,14 +8,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
from .wraps import account_initialization_required, setup_required
class CodeBasedExtensionQuery(BaseModel):
@ -116,11 +115,11 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
_, tenant_id = current_account_with_tenant()
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
]
@console_ns.doc("create_api_based_extension")
@ -130,9 +129,9 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str):
def post(self):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
@ -153,12 +152,12 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str, id: UUID):
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@ -169,9 +168,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def post(self, current_tenant_id: str, id: UUID):
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
@ -197,9 +196,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_tenant_id
def delete(self, current_tenant_id: str, id: UUID):
def delete(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)

View File

@ -2,13 +2,13 @@ from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_user, login_required
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
from .wraps import account_initialization_required, cloud_utm_record, setup_required
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
@console_ns.route("/features")
@ -24,34 +24,11 @@ class FeatureApi(Resource):
@login_required
@account_initialization_required
@cloud_utm_record
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
"""Get feature configuration for current tenant"""
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
).model_dump()
payload.pop("vector_space", None)
return payload
_, current_tenant_id = current_account_with_tenant()
@console_ns.route("/features/vector-space")
class FeatureVectorSpaceApi(Resource):
@console_ns.doc("get_tenant_feature_vector_space")
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
@console_ns.response(
200,
"Success",
console_ns.models[LimitationModel.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""Get vector-space usage and limit for current tenant"""
return FeatureService.get_vector_space(current_tenant_id).model_dump()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -22,13 +21,10 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
from . import console_ns
@ -65,8 +61,8 @@ class FileApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
@with_current_user
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -110,10 +106,10 @@ class FilePreviewApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
@with_current_tenant_id
def get(self, current_tenant_id: str, file_id: UUID):
file_id_str = str(file_id)
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
def get(self, file_id):
file_id = str(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
return {"content": text}

View File

@ -8,14 +8,8 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_user,
)
from libs.login import login_required
from models import Account
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
@ -76,10 +70,11 @@ class NotificationApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
@ -118,11 +113,11 @@ class NotificationDismissApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,

View File

@ -12,13 +12,11 @@ from controllers.common.errors import (
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService
@ -51,8 +49,7 @@ class RemoteFileUpload(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
@login_required
@with_current_user
def post(self, current_user: Account):
def post(self):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
@ -77,11 +74,12 @@ class RemoteFileUpload(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=current_user,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -9,16 +8,9 @@ from werkzeug.exceptions import Forbidden
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.base import ResponseModel
from libs.login import login_required
from models import Account
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
@ -99,8 +91,8 @@ class TagListApi(Resource):
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
@with_current_tenant_id
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
@ -116,9 +108,9 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
# Allow users with edit permission, or dataset editors (including dataset operators).
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -139,17 +131,17 @@ class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def patch(self, current_user: Account, tag_id: UUID):
tag_id_str = str(tag_id)
def patch(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -162,27 +154,28 @@ class TagUpdateDeleteApi(Resource):
@account_initialization_required
@edit_permission_required
@console_ns.response(204, "Tag deleted successfully")
def delete(self, tag_id: UUID):
tag_id_str = str(tag_id)
def delete(self, tag_id):
tag_id = str(tag_id)
TagService.delete_tag(tag_id_str)
TagService.delete_tag(tag_id)
return "", 204
def _require_tag_binding_edit_permission(current_user: Account) -> None:
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
@ -195,8 +188,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
@ -219,9 +212,8 @@ class TagBindingCollectionApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
return _create_tag_bindings(current_user)
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
@ -235,6 +227,5 @@ class TagBindingRemoveApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_current_user
def post(self, current_user: Account):
return _remove_tag_bindings(current_user)
def post(self):
return _remove_tag_bindings()

View File

@ -1,10 +1,8 @@
from urllib import parse
from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
import services
from configs import dify_config
@ -23,15 +21,15 @@ from controllers.console.auth.error import (
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@ -77,55 +75,7 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
if role != TenantAccountRole.DATASET_OPERATOR:
return True
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
return list(dict.fromkeys(email.lower() for email in emails))
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
new_member_count = 0
for email in emails:
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
new_member_count += 1
continue
exists = db.session.scalar(
select(TenantAccountJoin.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not exists:
new_member_count += 1
return new_member_count
def _count_current_members(tenant_id: str) -> int:
return (
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
)
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
if new_member_count <= 0:
return
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
if dify_config.ENTERPRISE_ENABLED:
workspace_members = features.workspace_members
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
raise WorkspaceMembersLimitExceeded()
return
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
members = features.members
current_member_count = _count_current_members(tenant_id)
if 0 < members.limit < current_member_count + new_member_count:
raise WorkspaceMembersLimitExceeded()
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
@console_ns.route("/workspaces/current/members")
@ -154,11 +104,12 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = _normalize_invitee_emails(args.emails)
invitee_emails = args.emails
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
@ -178,36 +129,37 @@ class MemberInviteEmailApi(Resource):
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
tenant_id = inviter.current_tenant.id
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
_check_member_invite_limits(tenant_id, new_member_count)
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
@ -223,7 +175,7 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id: UUID):
def delete(self, member_id):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -256,7 +208,7 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, member_id: UUID):
def put(self, member_id):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
@ -399,7 +351,7 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id: UUID):
def post(self, member_id):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)

View File

@ -532,7 +532,7 @@ class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type: str):
def get(self, model_type):
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -15,7 +15,6 @@ from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.plugin_service import PluginService
from fields.base import ResponseModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class ParserList(BaseModel):

View File

@ -166,10 +166,10 @@ class TenantListApi(Resource):
if tenant_plan:
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
else:
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
# Create a dictionary with tenant attributes

View File

@ -4,7 +4,6 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from flask import abort, request
from sqlalchemy import select
@ -17,7 +16,6 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models import Account
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
@ -84,7 +82,9 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.BILLING_ENABLED:
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -96,28 +96,21 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
if resource == "vector_space":
if not dify_config.BILLING_ENABLED:
return view(*args, **kwargs)
vector_space = FeatureService.get_vector_space(current_tenant_id)
if 0 < vector_space.limit <= vector_space.size:
abort(
403,
"The capacity of the knowledge storage space has reached the limit of your subscription.",
)
return view(*args, **kwargs)
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
members = features.members
apps = features.apps
vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit
if resource == "members" and 0 < members.limit <= members.size:
abort(403, "The number of members has reached the limit of your subscription.")
elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort(
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets
@ -147,7 +140,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
if resource == "add_segment":
if features.billing.subscription.plan == CloudPlan.SANDBOX:
@ -205,11 +198,15 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
utm_info = request.cookies.get("utm_info")
if dify_config.BILLING_ENABLED and utm_info:
_, current_tenant_id = current_account_with_tenant()
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -298,7 +295,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
features = FeatureService.get_features(current_tenant_id)
if features.knowledge_pipeline.publish_enabled:
return view(*args, **kwargs)
abort(403)
@ -312,6 +309,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
@ -329,6 +327,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
@ -496,25 +495,3 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return view(*args, **kwargs)
return decorated
def with_current_tenant_id[T, **P, R](
view: Callable[Concatenate[T, str, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_, current_tenant_id = current_account_with_tenant()
return view(self, current_tenant_id, *args, **kwargs)
return decorated
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated

View File

@ -1,5 +1,4 @@
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -50,8 +49,8 @@ class ImagePreviewApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID):
file_id_str = str(file_id)
def get(self, file_id):
file_id = str(file_id)
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True))
timestamp = args.timestamp
@ -60,7 +59,7 @@ class ImagePreviewApi(Resource):
try:
generator, mimetype = FileService(db.engine).get_image_preview(
file_id=file_id_str,
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
@ -92,14 +91,14 @@ class FilePreviewApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID):
file_id_str = str(file_id)
def get(self, file_id):
file_id = str(file_id)
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True))
try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id_str,
file_id=file_id,
timestamp=args.timestamp,
nonce=args.nonce,
sign=args.sign,
@ -160,10 +159,10 @@ class WorkspaceWebappLogoApi(Resource):
415: "Unsupported file type",
}
)
def get(self, workspace_id: UUID):
workspace_id_str = str(workspace_id)
def get(self, workspace_id):
workspace_id = str(workspace_id)
custom_config = TenantService.get_custom_config(workspace_id_str)
custom_config = TenantService.get_custom_config(workspace_id)
webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
if not webapp_logo_file_id:

View File

@ -1,5 +1,4 @@
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -46,19 +45,17 @@ class ToolFileApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID, extension: str):
file_id_str = str(file_id)
def get(self, file_id, extension):
file_id = str(file_id)
args = ToolFileQuery.model_validate(request.args.to_dict())
if not verify_tool_file_signature(
file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign
):
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
raise Forbidden("Invalid request.")
try:
tool_file_manager = ToolFileManager()
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id_str,
file_id,
)
if not stream or not tool_file:

View File

@ -1,128 +0,0 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
attach_anti_framing(bp)
api = ExternalApi(
bp,
version="1.0",
title="OpenAPI",
description="User-scoped programmatic API (bearer auth)",
)
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
# Register response/query models BEFORE importing controller modules so that
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.openapi._models import (
AccountPayload,
AccountResponse,
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppInfoResponse,
AppListQuery,
AppListResponse,
AppListRow,
AppRunRequest,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
MessageMetadata,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
RevokeResponse,
ServerVersionResponse,
SessionListResponse,
SessionRow,
TagItem,
UsageInfo,
WorkflowRunData,
WorkspaceDetailResponse,
WorkspaceListResponse,
WorkspacePayload,
WorkspaceSummaryResponse,
)
from fields.file_fields import FileResponse
register_schema_models(
openapi_ns,
AppDescribeQuery,
AppListQuery,
AppRunRequest,
DeviceCodeRequest,
DevicePollRequest,
DeviceLookupQuery,
DeviceMutateRequest,
PermittedExternalAppsListQuery,
)
register_response_schema_models(
openapi_ns,
TagItem,
UsageInfo,
MessageMetadata,
AppListRow,
AppListResponse,
AppInfoResponse,
AppDescribeInfo,
AppDescribeResponse,
WorkflowRunData,
AccountPayload,
WorkspacePayload,
AccountResponse,
SessionRow,
SessionListResponse,
PermittedExternalAppsListResponse,
RevokeResponse,
WorkspaceSummaryResponse,
WorkspaceListResponse,
WorkspaceDetailResponse,
DeviceCodeResponse,
DeviceLookupResponse,
DeviceMutateResponse,
FileResponse,
ServerVersionResponse,
)
from . import (
_meta,
account,
app_run,
apps,
apps_permitted_external,
files,
human_input_form,
index,
oauth_device,
oauth_device_sso,
workflow_events,
workspaces,
)
# Request models are imported from _models.py and registered above.
__all__ = [
"_meta",
"account",
"app_run",
"apps",
"apps_permitted_external",
"files",
"human_input_form",
"index",
"oauth_device",
"oauth_device_sso",
"workflow_events",
"workspaces",
]
api.add_namespace(openapi_ns)

View File

@ -1,66 +0,0 @@
"""Audit emission for openapi app-run endpoints.
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
def emit_app_run(
*,
app_id: str,
tenant_id: str,
caller_kind: str,
mode: str,
surface: str,
) -> None:
logger.info(
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
EVENT_APP_RUN_OPENAPI,
app_id,
tenant_id,
caller_kind,
mode,
surface,
extra={
"audit": True,
"event": EVENT_APP_RUN_OPENAPI,
"app_id": app_id,
"tenant_id": tenant_id,
"caller_kind": caller_kind,
"mode": mode,
"surface": surface,
},
)
def emit_wrong_surface(
*,
subject_type: str | None,
attempted_path: str,
client_id: str | None,
token_id: str | None,
) -> None:
logger.warning(
"audit: %s subject_type=%s attempted_path=%s",
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
subject_type,
attempted_path,
extra={
"audit": True,
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
"subject_type": subject_type,
"attempted_path": attempted_path,
"client_id": client_id,
"token_id": token_id,
},
)

View File

@ -1,143 +0,0 @@
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
from __future__ import annotations
from typing import Any, cast
from controllers.service_api.app.error import AppUnavailableError
from models import App
from models.model import AppMode
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": {},
"required": [],
}
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
def _file_object_shape() -> dict[str, Any]:
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
return {
"type": "object",
"properties": {
"type": {"type": "string"},
"transfer_method": {"type": "string"},
"url": {"type": "string"},
"upload_file_id": {"type": "string"},
},
"additionalProperties": True,
}
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
label = row.get("label") or row.get("variable", "")
base: dict[str, Any] = {"title": label} if label else {}
if row_type in ("text-input", "paragraph"):
out: dict[str, Any] = {"type": "string"} | base
max_length = row.get("max_length")
if isinstance(max_length, int) and max_length > 0:
out["maxLength"] = max_length
return out
if row_type == "select":
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
if row_type == "number":
return {"type": "number"} | base
if row_type == "file":
return _file_object_shape() | base
if row_type == "file-list":
return {
"type": "array",
"items": _file_object_shape(),
} | base
return None
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
"""Translate a user_input_form row list into (properties, required-list).
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
Unknown variable types are skipped (forward-compat).
"""
properties: dict[str, Any] = {}
required: list[str] = []
for row in form:
if not isinstance(row, dict) or len(row) != 1:
continue
((row_type, row_body),) = row.items()
if not isinstance(row_body, dict):
continue
variable = row_body.get("variable")
if not variable:
continue
schema = _row_to_schema(row_type, row_body)
if schema is None:
continue
properties[variable] = schema
if row_body.get("required"):
required.append(variable)
return properties, required
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
Raises `AppUnavailableError` on misconfigured apps.
"""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
return (
workflow.features_dict,
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
)
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
def build_input_schema(app: App) -> dict[str, Any]:
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
completion / workflow: `inputs` object only.
Raises `AppUnavailableError` on misconfigured apps.
"""
_, user_input_form = resolve_app_config(app)
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
properties: dict[str, Any] = {}
required: list[str] = []
if app.mode in _CHAT_FAMILY:
properties["query"] = {"type": "string", "minLength": 1}
required.append("query")
properties["inputs"] = {
"type": "object",
"properties": inputs_props,
"required": inputs_required,
"additionalProperties": False,
}
required.append("inputs")
return {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": properties,
"required": required,
}

View File

@ -1,23 +0,0 @@
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
Returns the server's project version and edition so the difyctl CLI can probe
compatibility without needing to be logged in. Mirrors the `_health` endpoint
in `index.py`.
"""
from flask_restx import Resource
from configs import dify_config
from controllers.openapi import openapi_ns
from controllers.openapi._models import ServerVersionResponse
@openapi_ns.route("/_version")
class VersionApi(Resource):
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
def get(self):
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
return ServerVersionResponse(
version=dify_config.project.version,
edition=edition,
).model_dump(mode="json")

View File

@ -1,344 +0,0 @@
"""Shared response substructures for openapi endpoints."""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from libs.helper import UUIDStrOrEmpty, uuid_value
from models.model import AppMode
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
MAX_PAGE_LIMIT = 200
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class MessageMetadata(BaseModel):
usage: UsageInfo | None = None
retriever_resources: list[dict[str, Any]] = []
class PaginationEnvelope[T](BaseModel):
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
page: int
limit: int
total: int
has_more: bool
data: list[T]
@classmethod
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
class TagItem(BaseModel):
name: str
class AppListRow(BaseModel):
id: str
name: str
description: str | None = None
mode: AppMode
tags: list[TagItem] = []
updated_at: str | None = None
created_by_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
class AppListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class PermittedExternalAppsListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class AppInfoResponse(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author: str | None = None
tags: list[TagItem] = []
class AppDescribeInfo(AppInfoResponse):
updated_at: str | None = None
service_api_enabled: bool
is_agent: bool = False
class AppDescribeResponse(BaseModel):
info: AppDescribeInfo | None = None
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
class ChatMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
conversation_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class CompletionMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class WorkflowRunData(BaseModel):
id: str
workflow_id: str
status: str
outputs: dict[str, Any] = Field(default_factory=dict)
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
class WorkflowRunResponse(BaseModel):
workflow_run_id: str
task_id: str
mode: Literal["workflow"] = "workflow"
data: WorkflowRunData
class AccountPayload(BaseModel):
id: str
email: str
name: str
class WorkspacePayload(BaseModel):
id: str
name: str
role: str
class AccountResponse(BaseModel):
subject_type: str
subject_email: str | None = None
subject_issuer: str | None = None
account: AccountPayload | None = None
workspaces: list[WorkspacePayload] = []
default_workspace_id: str | None = None
class SessionRow(BaseModel):
id: str
prefix: str
client_id: str
device_label: str
created_at: str | None = None
last_used_at: str | None = None
expires_at: str | None = None
class SessionListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[SessionRow]
class RevokeResponse(BaseModel):
status: str
class WorkspaceSummaryResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
class WorkspaceListResponse(BaseModel):
workspaces: list[WorkspaceSummaryResponse]
class WorkspaceDetailResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
created_at: str | None = None
class DeviceCodeResponse(BaseModel):
device_code: str
user_code: str
verification_uri: str
expires_in: int
interval: int
class DeviceLookupResponse(BaseModel):
valid: bool
expires_in_remaining: int = 0
client_id: str | None = None
class DeviceMutateResponse(BaseModel):
status: str
class ServerVersionResponse(BaseModel):
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
version: str
edition: Literal["SELF_HOSTED", "CLOUD"]
class AppDescribeQuery(BaseModel):
"""`?fields=` allow-list for GET /apps/<id>/describe.
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
"""
model_config = ConfigDict(extra="forbid")
fields: set[str] | None = None
workspace_id: str | None = None
@field_validator("workspace_id", mode="before")
@classmethod
def _validate_workspace_id(cls, v: object) -> str | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("workspace_id must be a string")
try:
import uuid as _uuid
_uuid.UUID(v)
except ValueError:
raise ValueError("workspace_id must be a valid UUID")
return v
@field_validator("fields", mode="before")
@classmethod
def _parse_fields(cls, v: object) -> set[str] | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("fields must be a comma-separated string")
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
members = {m.strip() for m in v.split(",") if m.strip()}
unknown = members - _ALLOWED_DESCRIBE_FIELDS
if unknown:
raise ValueError(f"unknown field(s): {sorted(unknown)}")
return members
class AppListQuery(BaseModel):
"""mode is a closed enum."""
workspace_id: str
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
tag: str | None = Field(None, max_length=100)
class AppRunRequest(BaseModel):
inputs: dict[str, Any]
query: str | None = None
files: list[dict[str, Any]] | None = None
conversation_id: UUIDStrOrEmpty | None = None
auto_generate_name: bool = True
workflow_id: str | None = None
workspace_id: UUIDStrOrEmpty | None = None
@field_validator("conversation_id", mode="before")
@classmethod
def _normalize_conv(cls, value: str | None) -> str | None:
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
class DeviceCodeRequest(BaseModel):
client_id: str
device_label: str
class DevicePollRequest(BaseModel):
device_code: str
client_id: str
class DeviceLookupQuery(BaseModel):
user_code: str
class DeviceMutateRequest(BaseModel):
user_code: str
class PermittedExternalAppsListQuery(BaseModel):
"""Strict (extra='forbid')."""
model_config = ConfigDict(extra="forbid")
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
class ExtSubjectAssertionClaims(BaseModel):
email: str = _EMAIL_FIELD
issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
class ApprovalGrantClaimsPayload(BaseModel):
subject_email: str = _EMAIL_FIELD
subject_issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
csrf_token: str = Field(min_length=1, max_length=128)

View File

@ -1,169 +0,0 @@
from __future__ import annotations
from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
MAX_PAGE_LIMIT,
AccountPayload,
AccountResponse,
PaginationEnvelope,
RevokeResponse,
SessionListResponse,
SessionRow,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from services.account_service import AccountService, TenantService
from services.oauth_device_flow import (
list_active_sessions,
revoke_oauth_token,
token_belongs_to_subject,
)
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email,
subject_issuer=ctx.subject_issuer,
account=None,
workspaces=[],
default_workspace_id=None,
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email or (account.email if account else None),
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
).model_dump(mode="json")
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
all_rows = list_active_sessions(db.session, ctx, now)
total = len(all_rows)
sliced = all_rows[(page - 1) * limit : page * limit]
items = [
SessionRow(
id=str(r.id),
prefix=r.prefix,
client_id=r.client_id,
device_label=r.device_label,
created_at=_iso(r.created_at),
last_used_at=_iso(r.last_used_at),
expires_at=_iso(r.expires_at),
)
for r in sliced
]
return (
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
200,
)
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
# 404 (not 403) on cross-subject so the endpoint doesn't leak
# token IDs that belong to other subjects.
if not token_belongs_to_subject(db.session, session_id, ctx):
raise NotFound("session not found")
revoke_oauth_token(db.session, redis_client, session_id)
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt.isoformat().replace("+00:00", "Z")
def _pick_default_workspace(memberships) -> str | None:
if not memberships:
return None
for join, tenant in memberships:
if getattr(join, "current", False):
return str(tenant.id)
return str(memberships[0][1].id)
def _workspace_payload(row) -> WorkspacePayload:
join, tenant = row
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
def _account_payload(account) -> AccountPayload:
return AccountPayload(id=str(account.id), email=account.email, name=account.name)

View File

@ -1,165 +0,0 @@
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
from __future__ import annotations
import logging
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import AppRunRequest
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
IsDraftWorkflowError,
WorkflowIdFormatError,
WorkflowNotFoundError,
)
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@contextmanager
def _translate_service_errors() -> Iterator[None]:
try:
yield
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
return AppGenerateService.generate(
app_model=app,
user=caller,
args=args,
invoke_from=InvokeFrom.OPENAPI,
streaming=streaming,
)
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
if not payload.query or not payload.query.strip():
raise UnprocessableEntity("query_required_for_chat")
args = payload.model_dump(exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False
args.setdefault("query", "")
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
if payload.query is not None:
raise UnprocessableEntity("query_not_supported_for_workflow")
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
AppMode.CHAT: _run_chat,
AppMode.AGENT_CHAT: _run_chat,
AppMode.ADVANCED_CHAT: _run_chat,
AppMode.COMPLETION: _run_completion,
AppMode.WORKFLOW: _run_workflow,
}
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
handler = _DISPATCH.get(app_model.mode)
if handler is None:
raise UnprocessableEntity("mode_not_runnable")
try:
stream_obj = handler(app_model, caller, payload)
except HTTPException:
raise
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
emit_app_run(
app_id=app_model.id,
tenant_id=app_model.tenant_id,
caller_kind=caller_kind,
mode=str(app_model.mode),
surface="apps",
)
return helper.compact_generate_response(stream_obj)
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -1,270 +0,0 @@
"""GET /openapi/v1/apps and per-app reads.
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is last → outermost → publishes the auth ContextVar before `require_scope`
reads it.
"""
from __future__ import annotations
import uuid as _uuid
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.common.fields import Parameters
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
from controllers.openapi._models import (
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppListQuery,
AppListResponse,
AppListRow,
TagItem,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
get_auth_ctx,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
require_scope(Scope.APPS_READ),
accept_subjects(SubjectType.ACCOUNT),
validate_bearer(accept=ACCEPT_USER_ANY),
]
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
_EMPTY_PARAMETERS: dict[str, Any] = {
"opening_statement": None,
"suggested_questions": [],
"user_input_form": [],
"file_upload": None,
"system_parameters": {},
}
class AppReadResource(Resource):
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
method_decorators = _APPS_READ_DECORATORS
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
ctx: AuthContext = get_auth_ctx()
try:
parsed_uuid = _uuid.UUID(app_id)
is_uuid = True
except ValueError:
parsed_uuid = None
is_uuid = False
if is_uuid:
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None:
raise NotFound("app not found")
else:
if not workspace_id:
raise UnprocessableEntity("workspace_id is required for name-based lookup")
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
if len(matches) == 0:
raise NotFound("app not found")
if len(matches) > 1:
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
for m in matches:
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
raise Conflict("".join(lines))
app = matches[0]
require_workspace_member(ctx, str(app.tenant_id))
return app, ctx
def parameters_payload(app: App) -> dict:
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
features_dict, user_input_form = resolve_app_config(app)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
def get(self, app_id: str):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
app, _ = self._load(app_id, workspace_id=query.workspace_id)
requested = query.fields
want_info = requested is None or "info" in requested
want_params = requested is None or "parameters" in requested
want_schema = requested is None or "input_schema" in requested
info = (
AppDescribeInfo(
id=str(app.id),
name=app.name,
mode=app.mode,
description=app.description,
tags=[TagItem(name=t.name) for t in app.tags],
author=app.author_name,
updated_at=app.updated_at.isoformat() if app.updated_at else None,
service_api_enabled=bool(app.enable_api),
is_agent=app.mode in ("agent-chat", "advanced-chat"),
)
if want_info
else None
)
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
if want_params:
try:
parameters = parameters_payload(app)
except AppUnavailableError:
parameters = dict(_EMPTY_PARAMETERS)
if want_schema:
try:
input_schema = build_input_schema(app)
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return (
AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
).model_dump(mode="json", exclude_none=False),
200,
)
@openapi_ns.route("/apps")
class AppListApi(Resource):
method_decorators = _APPS_READ_DECORATORS
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
def get(self):
ctx: AuthContext = get_auth_ctx()
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
workspace_id = query.workspace_id
require_workspace_member(ctx, workspace_id)
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
mode="json"
),
200,
)
if query.name:
try:
parsed_uuid = _uuid.UUID(query.name)
except ValueError:
parsed_uuid = None
else:
parsed_uuid = None
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None or str(app.tenant_id) != workspace_id:
return empty
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
item = AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[TagItem(name=t.name) for t in app.tags],
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=getattr(app, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
return env.model_dump(mode="json"), 200
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
params = AppListParams(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else "all", # type:ignore
name=query.name,
tag_ids=tag_ids,
status="normal",
# Visibility gate pushed into the query — pagination.total stays
# consistent across pages because invisible rows never count.
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
if pagination is None:
return empty
tenant_name = None
if pagination.items:
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
items = [
AppListRow(
id=str(r.id),
name=r.name,
description=r.description,
mode=r.mode,
tags=[TagItem(name=t.name) for t in r.tags],
updated_at=r.updated_at.isoformat() if r.updated_at else None,
created_by_name=getattr(r, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
for r in pagination.items
]
env = AppListResponse(
page=query.page,
limit=query.limit,
total=cast(int, pagination.total),
has_more=query.page * query.limit < cast(int, pagination.total),
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -1,102 +0,0 @@
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
(public / sso_verified). License-gated: CE deploys never enable the
EE blueprint chain so this module is unreachable there.
"""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AppListRow,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
method_decorators = [
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
license_required,
accept_subjects(SubjectType.EXTERNAL_SSO),
validate_bearer(accept=ACCEPT_USER_ANY),
enterprise_only,
]
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
def get(self):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
page_result = list_permitted_apps(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else None,
name=query.name,
)
if not page_result.app_ids:
env = PermittedExternalAppsListResponse(
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
)
return env.model_dump(mode="json"), 200
apps_by_id: dict[str, App] = {
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
}
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
items: list[AppListRow] = []
for app_id in page_result.app_ids:
app = apps_by_id.get(app_id)
if not app or app.status != "normal":
continue
tenant = tenants_by_id.get(str(app.tenant_id))
items.append(
AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[], # tenant-scoped; not surfaced cross-tenant
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=None, # cross-tenant author leak prevention
workspace_id=str(app.tenant_id),
workspace_name=tenant.name if tenant else None,
)
)
env = PermittedExternalAppsListResponse(
page=query.page,
limit=query.limit,
total=page_result.total,
has_more=query.page * query.limit < page_result.total,
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -1,3 +0,0 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -1,46 +0,0 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
if FeatureService.get_system_features().webapp_auth.enabled:
return AclStrategy()
return MembershipStrategy()
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -1,68 +0,0 @@
"""Mutable per-request context for the openapi auth pipeline.
Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context — handlers
read populated values via the decorator's kwargs unpacking.
Context is intentionally decoupled from Flask's ``Request``: the pipeline
guard extracts whatever transport-level inputs the steps need (bearer
token, path params) at the boundary and writes them into Context fields,
so steps stay testable without a request object and won't leak coupling
to a specific framework.
"""
from __future__ import annotations
import uuid
from collections.abc import Mapping
from contextvars import Token
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from werkzeug.exceptions import Unauthorized
from libs.oauth_bearer import AuthContext, Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
required_scope: Scope
bearer_token: str | None = None
path_params: Mapping[str, str] = field(default_factory=dict)
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None
auth_ctx_reset_token: Token[AuthContext] | None = None
@property
def must_tenant(self) -> Tenant:
if not self.tenant:
raise Unauthorized("tenant is not associated")
return self.tenant
@property
def must_subject_type(self) -> SubjectType:
if not self.subject_type:
raise Unauthorized("subject_type unset — BearerCheck did not run")
return self.subject_type
class Step(Protocol):
"""One responsibility. Mutate ctx or raise to short-circuit."""
def __call__(self, ctx: Context) -> None: ...

View File

@ -1,51 +0,0 @@
"""Pipeline IS the auth scheme.
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from functools import wraps
from flask import request
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
class Pipeline:
def __init__(self, *steps: Step) -> None:
self._steps = steps
def run(self, ctx: Context) -> None:
for step in self._steps:
step(ctx)
def guard(self, *, scope: Scope):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# Extract transport-level inputs at the boundary so steps
# stay decoupled from Flask's request object.
ctx = Context(
required_scope=scope,
bearer_token=extract_bearer(request),
path_params=dict(request.view_args or {}),
)
try:
self.run(ctx)
kwargs.update(
app_model=ctx.app,
caller=ctx.caller,
caller_kind=ctx.caller_kind,
)
return view(*args, **kwargs)
finally:
if ctx.auth_ctx_reset_token is not None:
reset_auth_ctx(ctx.auth_ctx_reset_token)
return decorated
return decorator

View File

@ -1,170 +0,0 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also publishes the
resolved identity to the openapi auth ``ContextVar`` (the same one the
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
surface gate and any handler reading the request-scoped context has a single
source of truth across both auth-attach paths. The reset token is stashed
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
its `finally` so worker-thread reuse can't leak identity across requests.
"""
from __future__ import annotations
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
check_workspace_membership,
get_authenticator,
set_auth_ctx,
)
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here.
Also publishes the resolved `AuthContext` via
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
``validate_bearer`` writes — so the surface gate + downstream readers
don't see two different identity sources. The reset token is parked on
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
def __call__(self, ctx: Context) -> None:
if not ctx.bearer_token:
raise Unauthorized("bearer required")
try:
authn = get_authenticator().authenticate(ctx.bearer_token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
class SurfaceCheck:
"""Reject the request if the resolved subject is not in `accepted`."""
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
self._accepted = accepted
def __call__(self, ctx: Context) -> None:
check_surface(self._accepted)
class AppResolver:
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route — that is the design lock-in (no body /
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
``ctx.path_params`` at the boundary so this step doesn't need to know
about the request object.
"""
def __call__(self, ctx: Context) -> None:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only — SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.must_tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy
def __call__(self, ctx: Context) -> None:
if not self._resolve().authorize(ctx):
raise Forbidden("subject_no_app_access")
class CallerMount:
def __init__(self, *mounters: CallerMounter) -> None:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.must_subject_type):
m.mount(ctx)
return
raise Unauthorized("no caller mounter for subject type")
__all__ = [
"AppAuthzCheck",
"AppResolver",
"AuthContext",
"BearerCheck",
"CallerMount",
"ScopeCheck",
"SurfaceCheck",
"WorkspaceMembershipCheck",
]

View File

@ -1,168 +0,0 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
"""Per-app ACL, evaluated in two stages.
The EE gateway has already enforced tenancy and workspace membership
by the time this strategy runs, so AclStrategy only owns per-app ACL:
1. Subject vs access-mode compatibility (pure rule table). External-SSO
bearers belong to public-facing apps only; account bearers cover the
full set. A mismatch is an immediate deny — no IO.
2. For modes that pair with the subject, decide whether the inner
permission API must run. Only `PRIVATE` (per-app selected-user list)
requires it; the remaining modes are pass-through.
"""
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
SubjectType.ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
SubjectType.EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
def authorize(self, ctx: Context) -> bool:
if ctx.app is None:
return False
access_mode = self._fetch_access_mode(ctx.app.id)
if access_mode is None:
return False
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
return False
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
return True
return self._inner_permission_check(ctx)
@staticmethod
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if settings is None:
return None
try:
return WebAppAccessMode(settings.access_mode)
except ValueError:
return None
@classmethod
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
def _inner_permission_check(self, ctx: Context) -> bool:
if ctx.app is None:
return False
user_id = self._resolve_user_id(ctx)
if user_id is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_id=ctx.app.id,
)
@staticmethod
def _resolve_user_id(ctx: Context) -> str | None:
if ctx.subject_type == SubjectType.ACCOUNT:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
class MembershipStrategy:
"""Tenant-membership fallback.
Used when webapp-auth is disabled (CE deployment). Account-bearing
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
denied (it requires the webapp-auth surface).
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
if ctx.tenant is None:
return False
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # type:ignore
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
class CallerMounter(Protocol):
def applies_to(self, subject_type: SubjectType) -> bool: ...
def mount(self, ctx: Context) -> None: ...
class AccountMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)
ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -1,89 +0,0 @@
"""Surface gate.
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
step) is the pipeline-level form. Both delegate to `check_surface` so the
audit emit + canonical-path message are single-sourced.
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
``openapi.wrong_surface_denied``.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from flask import request
from werkzeug.exceptions import Forbidden
from controllers.openapi._audit import emit_wrong_surface
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
_CANONICAL_PATH: dict[SubjectType, str] = {
SubjectType.ACCOUNT: "/openapi/v1/apps",
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
}
F = TypeVar("F", bound=Callable[..., object])
def check_surface(accepted: frozenset[SubjectType]) -> None:
"""Enforce that the resolved subject is in ``accepted``.
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
set the bearer layer didn't run — that's a wiring bug, not a
user-driven failure, so surface it as a ``RuntimeError`` instead of
a silent 403.
"""
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
)
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
if subject in accepted:
return
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
emit_wrong_surface(
subject_type=subject.value if subject else None,
attempted_path=request.path,
client_id=getattr(ctx, "client_id", None),
token_id=_stringify(getattr(ctx, "token_id", None)),
)
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
accepted_set: frozenset[SubjectType] = frozenset(accepted)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
check_surface(accepted_set)
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco
def _coerce_subject_type(raw: object) -> SubjectType | None:
if raw is None:
return None
if isinstance(raw, SubjectType):
return raw
if isinstance(raw, str):
return SubjectType(raw)
return None
def _stringify(value: object) -> str | None:
if value is None:
return None
return str(value)

View File

@ -1,72 +0,0 @@
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from werkzeug.exceptions import BadRequest
import services
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from extensions.ext_database import db
from fields.file_fields import FileResponse
from libs.oauth_bearer import Scope
from models import Account, App
from services.file_service import FileService
@openapi_ns.route("/apps/<string:app_id>/files/upload")
class AppFileUploadApi(Resource):
@openapi_ns.doc("upload_file_for_app_input")
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
@openapi_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request — no file or filename missing",
401: "Unauthorized — invalid or expired bearer token",
413: "File too large",
415: "Unsupported file type or blocked extension",
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.stream.read(),
mimetype=file.mimetype,
user=caller,
)
except ValueError as exc:
raise BadRequest(str(exc))
except services.errors.file.FileTooLargeError as exc:
raise FileTooLargeError(exc.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError(exc.description)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201

View File

@ -1,107 +0,0 @@
"""
OpenAPI bearer-authed human input form endpoints.
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
"""
from __future__ import annotations
import json
import logging
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
from libs.oauth_bearer import Scope
from models.model import App
from services.human_input_service import FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_openapi(form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
raise NotFound("Form not found")
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
submission_user_id: str | None = None
submission_end_user_id: str | None = None
if caller_kind == "account":
submission_user_id = caller.id
else:
submission_end_user_id = caller.id
if form.recipient_type is None:
logger.warning("Recipient type is None for form, form_token=%s", form_token)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=form.recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=submission_user_id,
submission_end_user_id=submission_end_user_id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -1,9 +0,0 @@
from flask_restx import Resource
from controllers.openapi import openapi_ns
@openapi_ns.route("/_health")
class HealthApi(Resource):
def get(self):
return {"ok": True}

View File

@ -1,398 +0,0 @@
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
sub-groups in one module:
Protocol (RFC 8628, public + rate-limited):
POST /oauth/device/code
POST /oauth/device/token
GET /oauth/device/lookup
Approval (account branch, console-cookie authed):
POST /oauth/device/approve
POST /oauth/device/deny
SSO branch lives in oauth_device_sso.py.
"""
from __future__ import annotations
import logging
from typing import Any
from flask import request
from flask_login import login_required
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
LIMIT_DEVICE_CODE_PER_IP,
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
SlowDownDecision,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# =========================================================================
# Validation helpers
# =========================================================================
def _validate_json[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _validate_query[M: BaseModel](model: type[M]) -> M:
try:
return model.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
# =========================================================================
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
# =========================================================================
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
def post(self):
payload = _validate_json(DeviceCodeRequest)
client_id = payload.client_id
device_label = payload.device_label
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
return {"error": "unsupported_client"}, 400
store = DeviceFlowRedis(redis_client)
ip = extract_remote_ip(request)
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
return {
"device_code": device_code,
"user_code": user_code,
"verification_uri": _verification_uri(),
"expires_in": expires_in,
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
}, 200
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
def post(self):
payload = _validate_json(DevicePollRequest)
device_code = payload.device_code
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
"""Read-only — public for pre-validate before login. user_code is
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
"""
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
@rate_limit(LIMIT_LOOKUP_PUBLIC)
def get(self):
payload = _validate_query(DeviceLookupQuery)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
_device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
return {
"valid": True,
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
"client_id": state.client_id,
}, 200
# =========================================================================
# Approval endpoints — account branch (cookie-authed)
# =========================================================================
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
# SET NX guard — without it, two in-flight approves both pass
# PENDING, both mint, and the second upsert silently rotates the
# first caller into an already-revoked token.
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
return {"error": "approve_in_progress"}, 409
try:
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=tenant)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=account.email,
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
account_id=str(account.id),
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
poll_payload = _build_account_poll_payload(account, tenant, mint)
try:
store.approve(
device_code,
subject_email=account.email,
account_id=str(account.id),
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError):
# Row minted but state vanished — roll forward; the orphan
# token is revocable via auth devices list / Authorized Apps.
logger.exception("device_flow: approve raced on %s", device_code)
return {"error": "state_lost"}, 409
finally:
redis_client.delete(guard_key)
_emit_approve_audit(state, account, tenant, mint)
return {"status": "approved"}, 200
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
try:
store.deny(device_code)
except (StateNotFoundError, InvalidTransitionError):
logger.exception("device_flow: deny raced on %s", device_code)
return {"error": "state_lost"}, 409
_emit_deny_audit(state)
return {"status": "denied"}, 200
# =========================================================================
# Helpers
# =========================================================================
def _verification_uri() -> str:
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
if base:
return f"{base.rstrip('/')}/device"
return f"{request.host_url.rstrip('/')}/device"
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id,
state.created_ip,
poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None
if tenant and any(w.id == str(tenant) for w in workspaces):
default_ws_id = str(tenant)
if default_ws_id is None:
for _t, m in rows:
if getattr(m, "current", False):
default_ws_id = str(m.tenant_id)
break
if default_ws_id is None and workspaces:
default_ws_id = workspaces[0].id
payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.ACCOUNT,
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
"workspaces": [w.model_dump(mode="json") for w in workspaces],
"default_workspace_id": default_ws_id,
"token_id": str(mint.token_id),
}
return payload
def _emit_approve_audit(state, account, tenant, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
mint.token_id,
account.email,
state.client_id,
state.device_label,
mint.expires_at,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"token_id": str(mint.token_id),
"subject_type": SubjectType.ACCOUNT,
"subject_email": account.email,
"account_id": str(account.id),
"tenant_id": tenant,
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["full"],
"expires_at": mint.expires_at.isoformat(),
},
)
def _emit_deny_audit(state) -> None:
logger.warning(
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
state.client_id,
state.device_label,
extra={
"audit": True,
"event": "oauth.device_flow_denied",
"client_id": state.client_id,
"device_label": state.device_label,
},
)

View File

@ -1,365 +0,0 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
EE-only. Browser flow:
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
Function-based (raw @bp.route) rather than Resource classes because the
handlers do redirects + cookie kwargs that don't fit the Resource shape.
"""
from __future__ import annotations
import logging
import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from pydantic import ValidationError
from werkzeug.exceptions import (
BadGateway,
BadRequest,
Conflict,
Forbidden,
NotFound,
Unauthorized,
)
from configs import dify_config
from controllers.openapi import bp
from controllers.openapi._models import ExtSubjectAssertionClaims
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import jws
from libs.device_flow_security import (
APPROVAL_GRANT_COOKIE_NAME,
ApprovalGrantClaims,
approval_grant_cleared_cookie_kwargs,
approval_grant_cookie_kwargs,
consume_approval_grant_nonce,
consume_sso_assertion_nonce,
enterprise_only,
mint_approval_grant,
verify_approval_grant,
)
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
from libs.rate_limit import (
LIMIT_APPROVE_EXT_PER_EMAIL,
LIMIT_SSO_INITIATE_PER_IP,
enforce,
rate_limit,
)
from services.account_service import AccountService
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
def _trusted_origin() -> str:
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
if not base:
raise BadGateway("console_api_url_unset")
return base
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
user_code = (request.args.get("user_code") or "").strip().upper()
if not user_code:
raise BadRequest("user_code required")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise BadRequest("invalid_user_code")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise BadRequest("invalid_user_code")
origin = _trusted_origin()
keyset = jws.KeySet.from_shared_secret()
signed_state = jws.sign(
keyset,
payload={
"redirect_url": "",
"app_code": "",
"intent": "device_flow",
"user_code": user_code,
"nonce": secrets.token_urlsafe(16),
"return_to": "",
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
},
aud=jws.AUD_STATE_ENVELOPE,
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
)
try:
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
except Exception as e:
logger.warning("sso-initiate: enterprise call failed: %s", e)
raise BadGateway("sso_initiate_failed") from e
url = (reply or {}).get("url")
if not url:
raise BadGateway("sso_initiate_missing_url")
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
resp = redirect(url, code=302)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
blob = request.args.get("sso_assertion")
if not blob:
raise BadRequest("sso_assertion required")
keyset = jws.KeySet.from_shared_secret()
try:
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
try:
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
except ValidationError as e:
logger.warning("sso-complete: claim shape invalid: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
raise BadRequest("invalid_sso_assertion")
user_code = claims.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise Conflict("user_code_not_pending")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims.email):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
reason="email_belongs_to_dify_account",
)
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
iss = _trusted_origin()
cookie_value, _ = mint_approval_grant(
keyset=keyset,
iss=iss,
subject_email=claims.email,
subject_issuer=claims.issuer,
user_code=user_code,
)
resp = redirect("/device?sso_verified=1", code=302)
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
return resp
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("no_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approval-context: bad cookie: %s", e)
raise Unauthorized("no_session") from e
return jsonify(
{
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}
), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("invalid_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approve-external: bad cookie: %s", e)
raise Unauthorized("invalid_session") from e
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
csrf_header = request.headers.get("X-CSRF-Token", "")
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
raise Forbidden("csrf_mismatch")
data = request.get_json(silent=True) or {}
body_user_code = (data.get("user_code") or "").strip().upper()
if body_user_code != claims.user_code:
raise BadRequest("user_code_mismatch")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(claims.user_code)
if found is None:
raise NotFound("user_code_not_pending")
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
raise Forbidden("email_belongs_to_dify_account")
if not consume_approval_grant_nonce(redis_client, claims.nonce):
raise Unauthorized("session_already_consumed")
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=None)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=claims.subject_email,
subject_issuer=claims.subject_issuer,
account_id=None,
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
# SSO branch of the shared PollPayload contract: account/workspace
# fields are zero-filled (`None` / `[]`) for parity with the account
# branch in `oauth_device._build_account_poll_payload`.
poll_payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
"token_id": str(mint.token_id),
}
try:
store.approve(
device_code,
subject_email=claims.subject_email,
account_id=None,
subject_issuer=claims.subject_issuer,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError) as e:
logger.exception("approve-external: state transition raced")
raise Conflict("state_lost") from e
_emit_approve_external_audit(state, claims, mint)
resp = make_response(jsonify({"status": "approved"}), 200)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@dataclass(frozen=True)
class _RejectedClaims:
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
event without reaching for the full dataclass.
"""
subject_email: str
subject_issuer: str
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
logger.warning(
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
reason,
extra={
"audit": True,
"event": "oauth.device_flow_rejected",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"reason": reason,
"client_id": state.client_id,
"device_label": state.device_label,
},
)
def _emit_approve_external_audit(state, claims, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
mint.token_id,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"token_id": str(mint.token_id),
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["apps:run"],
"expires_at": mint.expires_at.isoformat(),
},
)

View File

@ -1,119 +0,0 @@
"""
OpenAPI bearer-authed workflow reconnect event stream endpoint.
GET /apps/<app_id>/tasks/<task_id>/events
— reconnect to the SSE stream for a paused/running workflow run.
`task_id` is treated as `workflow_run_id`.
"""
from __future__ import annotations
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from libs.oauth_bearer import Scope
from models.enums import CreatorUserRole
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.response(200, "SSE event stream")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if caller_kind == "account":
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
else:
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=caller,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
else:
generator = WorkflowAppGenerator()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.OPENAPI,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -1,78 +0,0 @@
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
counterparts to the cookie-authed /console/api/workspaces endpoints.
Account bearers (dfoa_) see every tenant they're a member of. External
SSO bearers (dfoe_) have no account_id and so see an empty list — that
matches /openapi/v1/account.
"""
from __future__ import annotations
from itertools import starmap
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from models import Tenant, TenantAccountJoin
from services.account_service import TenantService
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self):
ctx = get_auth_ctx()
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self, workspace_id: str):
ctx = get_auth_ctx()
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
if row is None:
raise NotFound("workspace not found")
tenant, membership = row
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
return WorkspaceSummaryResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
)
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
return WorkspaceDetailResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
)

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -79,10 +78,10 @@ class AnnotationReplyActionStatusApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App, job_id: UUID, action: str):
def get(self, app_model: App, job_id, action):
"""Get the status of an annotation reply action job."""
job_id_str = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
@ -90,10 +89,10 @@ class AnnotationReplyActionStatusApi(Resource):
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@service_api_ns.route("/apps/annotations")
@ -174,11 +173,11 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def put(self, app_model: App, annotation_id: UUID):
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -195,7 +194,7 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def delete(self, app_model: App, annotation_id: UUID):
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return "", 204

View File

@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -196,7 +195,7 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -225,7 +224,7 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -267,7 +266,7 @@ class ConversationVariablesApi(Resource):
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, c_id: UUID):
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
Conversational variables are only available for chat applications.
@ -313,7 +312,7 @@ class ConversationVariableDetailApi(Resource):
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def put(self, app_model: App, end_user: EndUser, c_id: UUID, variable_id: UUID):
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
Allows updating the value of a specific conversation variable.
@ -324,13 +323,13 @@ class ConversationVariableDetailApi(Resource):
raise NotChatAppError()
conversation_id = str(c_id)
variable_id_str = str(variable_id)
variable_id = str(variable_id)
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id_str, end_user, payload.value
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:

View File

@ -1,6 +1,5 @@
import logging
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -51,20 +50,20 @@ class FilePreviewApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
def get(self, app_model: App, end_user: EndUser, file_id: str):
"""
Preview/Download a file that was uploaded via Service API.
Provides secure file preview/download functionality.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
file_id_str = str(file_id)
file_id = str(file_id)
# Parse query parameters
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@ -7,6 +7,7 @@ paused human input forms in workflow/chatflow runs.
import json
import logging
from collections.abc import Sequence
from flask import Response
from flask_restx import Resource
@ -18,6 +19,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from graphon.nodes.human_input.entities import FormInputConfig
from libs.helper import to_timestamp
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
@ -28,11 +30,11 @@ logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
def _jsonify_form_definition(form: Form, *, inputs: Sequence[FormInputConfig] = ()) -> Response:
definition_payload = form.get_definition().model_dump(mode="json")
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"inputs": [form_input.model_dump(mode="json") for form_input in inputs],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
@ -75,7 +77,8 @@ class WorkflowHumanInputFormApi(Resource):
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
inputs = service.resolve_form_inputs(form)
return _jsonify_form_definition(form, inputs=inputs)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")

View File

@ -1,5 +1,4 @@
import logging
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -95,19 +94,19 @@ class MessageFeedbackApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
def post(self, app_model: App, end_user: EndUser, message_id):
"""Submit feedback for a message.
Allows users to rate messages as like/dislike and provide optional feedback content.
"""
message_id_str = str(message_id)
message_id = str(message_id)
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id_str,
message_id=message_id,
user=end_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
@ -160,19 +159,19 @@ class MessageSuggestedApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
def get(self, app_model: App, end_user: EndUser, message_id):
"""Get suggested follow-up questions for a message.
Returns AI-generated follow-up questions based on the message content.
"""
message_id_str = str(message_id)
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.SERVICE_API
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal, override
from typing import Literal
from dateutil.parser import isoparse
from flask import request
@ -76,13 +76,11 @@ def _enum_value(value):
class WorkflowRunStatusField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:

View File

@ -1,18 +1,13 @@
from typing import Any, Literal
from uuid import UUID
from typing import Any, Literal, cast
from flask import request
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -22,10 +17,9 @@ from controllers.service_api.wraps import (
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -125,21 +119,6 @@ class TagUnbindingPayload(BaseModel):
return self
class KnowledgeTagResponse(ResponseModel):
model_config = ConfigDict(coerce_numbers_to_str=True)
id: str
name: str
type: str
# TODO: The public Service API docs expose binding_count as string|null.
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
binding_count: str | None = None
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
pass
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
@ -148,29 +127,6 @@ class DatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
# todo: duplicate code, but the partial_member_list has different nullability
class DatasetListResponse(ResponseModel):
data: list[DatasetDetailResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetBoundTagResponse(ResponseModel):
id: str
name: str
class DatasetBoundTagListResponse(ResponseModel):
data: list[DatasetBoundTagResponse]
total: int
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@ -181,17 +137,9 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
register_response_schema_models(
service_api_ns,
SimpleResultResponse,
KnowledgeTagResponse,
KnowledgeTagListResponse,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetBoundTagListResponse,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
@service_api_ns.route("/datasets")
@ -206,18 +154,9 @@ class DatasetListApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
@service_api_ns.response(
200,
"Datasets retrieved successfully",
service_api_ns.models[DatasetListResponse.__name__],
)
def get(self, tenant_id):
"""Resource for getting datasets."""
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = DatasetListQuery.model_validate(query_params)
query = DatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
datasets, total = DatasetService.get_datasets(
@ -236,17 +175,17 @@ class DatasetListApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
else:
item["embedding_available"] = False
item["embedding_available"] = False # type: ignore
else:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@ -254,7 +193,7 @@ class DatasetListApi(DatasetApiResource):
"total": total,
"page": query.page,
}
return dump_response(DatasetListResponse, response), 200
return response, 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@ -266,11 +205,6 @@ class DatasetListApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200,
"Dataset created successfully",
service_api_ns.models[DatasetDetailResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
@ -314,7 +248,7 @@ class DatasetListApi(DatasetApiResource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return dump_response(DatasetDetailResponse, dataset), 200
return marshal(dataset, dataset_detail_fields), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>")
@ -332,12 +266,7 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset retrieved successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
def get(self, _, dataset_id: UUID):
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -346,7 +275,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = dump_response(DatasetDetailResponse, dataset)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -378,13 +307,7 @@ class DatasetApi(DatasetApiResource):
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return (
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
mode="json",
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
),
200,
)
return data, 200
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@ -398,13 +321,8 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset updated successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id: UUID):
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -453,7 +371,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = dump_response(DatasetDetailResponse, dataset)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
@ -466,7 +384,7 @@ class DatasetApi(DatasetApiResource):
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
return result_data, 200
@service_api_ns.doc("delete_dataset")
@service_api_ns.doc(description="Delete a dataset")
@ -480,7 +398,7 @@ class DatasetApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id: UUID):
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
@ -535,7 +453,7 @@ class DocumentStatusApi(DatasetApiResource):
400: "Bad request - invalid action",
}
)
def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
"""
Batch update document status.
@ -579,7 +497,7 @@ class DocumentStatusApi(DatasetApiResource):
except ValueError as e:
raise InvalidActionError(str(e))
return dump_response(SimpleResultResponse, {"result": "success"}), 200
return {"result": "success"}, 200
@service_api_ns.route("/datasets/tags")
@ -592,18 +510,14 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[KnowledgeTagListResponse.__name__],
)
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
return dump_response(KnowledgeTagListResponse, tags), 200
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@ -615,11 +529,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag created successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@ -629,10 +538,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@ -645,11 +553,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag updated successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -661,10 +564,9 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@ -749,11 +651,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[DatasetBoundTagListResponse.__name__],
)
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
@ -761,4 +658,5 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
response = {"data": tags_list, "total": len(tags)}
return response, 200

Some files were not shown because too many files have changed in this diff Show More