mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 20:07:46 +08:00
Compare commits
156 Commits
codex/alig
...
feat/hitl-
| Author | SHA1 | Date | |
|---|---|---|---|
| 44ae313dae | |||
| 67a17c9924 | |||
| 6696e49df3 | |||
| bd94db64b7 | |||
| f0fef15d2a | |||
| e49607f545 | |||
| 52db8b7e72 | |||
| 6cbc97590d | |||
| 49fa6aa975 | |||
| 418049090f | |||
| 1150f492aa | |||
| 3f8f7a9b87 | |||
| 205ea1bd98 | |||
| e26042cd90 | |||
| 8626e50ee6 | |||
| beafa0308e | |||
| 8d200afb32 | |||
| 41a88a6580 | |||
| 45c20a24f2 | |||
| 0b4348ee1c | |||
| d2c5539937 | |||
| ca9e738d53 | |||
| 423edab679 | |||
| e906ad1289 | |||
| 1c1970e47b | |||
| 450e19b46c | |||
| f69779fc0c | |||
| fdc20e460b | |||
| 68ea88c82d | |||
| f045359e67 | |||
| 4ff95853c2 | |||
| c96d108198 | |||
| c067498f73 | |||
| da90c934c0 | |||
| 226ce25596 | |||
| e897e275ef | |||
| 5f1b47c21b | |||
| 186fd60712 | |||
| 221a5020ad | |||
| 51b3e63472 | |||
| 070f8bcf14 | |||
| 698bb902ca | |||
| 9a65969744 | |||
| 7facbfcb71 | |||
| fce6d017d6 | |||
| f9729e8764 | |||
| b3c007461d | |||
| 3169ae7f53 | |||
| ac05dc9fa3 | |||
| 3f372352d8 | |||
| c86fd4cba0 | |||
| 26c14fd58a | |||
| 8bb70e78b7 | |||
| 94ad8674d6 | |||
| 998201f6e3 | |||
| 17759c0b80 | |||
| 3d445e1d95 | |||
| 1e6700b679 | |||
| 4b5b00ce63 | |||
| 07fea50216 | |||
| f1833fdb08 | |||
| e9070daaaa | |||
| 0e389d223f | |||
| 132f80dd9e | |||
| e099ba8679 | |||
| c4b2985361 | |||
| 343982bd46 | |||
| bba7001de2 | |||
| 62efb66a2f | |||
| e3e4f77de1 | |||
| a75208b432 | |||
| eadaaa1aa0 | |||
| 74c4d720d4 | |||
| c8d6ad117e | |||
| 7133754a31 | |||
| f1adc60822 | |||
| 58af8aa7fe | |||
| ed98925f11 | |||
| a9de4bd96b | |||
| d65cc21e85 | |||
| 23d39beeed | |||
| 3f35f3594b | |||
| 04ed797ac9 | |||
| 4c386e3ea7 | |||
| 3f6559dd60 | |||
| c0bedd9118 | |||
| b73a4d2700 | |||
| 02e42fd66f | |||
| a0f8db5516 | |||
| 37681bce8c | |||
| 23e59c6778 | |||
| 51e181c588 | |||
| d6f607f6e7 | |||
| cd91757623 | |||
| 651dfe5dca | |||
| 21a9c8d59c | |||
| 4fce9ee8e5 | |||
| 90ab734a05 | |||
| ad43a46c37 | |||
| e5b5c1fa3b | |||
| 6e4fa39db5 | |||
| e994009476 | |||
| 4da8afaed5 | |||
| bdecea34a3 | |||
| 9ad5d89b07 | |||
| 74f17d0ec8 | |||
| 7a12d46a45 | |||
| cec437b35b | |||
| 1c5d877372 | |||
| a8e663863d | |||
| 60f577fd11 | |||
| 46747993d4 | |||
| b8481f6d6f | |||
| 8ce356c98f | |||
| d72794bc67 | |||
| 7316a1be2b | |||
| 1a0776ce9b | |||
| 7c348e994c | |||
| 1a3c1a9b32 | |||
| 6263bb1b74 | |||
| d494e42166 | |||
| 5d20502494 | |||
| 75e5429534 | |||
| ccb43eb856 | |||
| 4f63b2b162 | |||
| 1b45cdb08a | |||
| cfb501c1f6 | |||
| 703228ffed | |||
| cbe2f66f1b | |||
| b879748ba0 | |||
| 29cad80e62 | |||
| 2ad35e56bc | |||
| add9260e58 | |||
| 94b8f8f170 | |||
| 67c832e60e | |||
| e9fb3bd751 | |||
| 80ede7cdb5 | |||
| 5309b56225 | |||
| 8d3ddee7d3 | |||
| c2fd595a82 | |||
| 85d05f5113 | |||
| ccf61c0372 | |||
| 49195fffdd | |||
| 4002d58171 | |||
| cae5315c18 | |||
| eb2eefdbb5 | |||
| 82d410325b | |||
| 00f0f6d040 | |||
| 022d73d0ed | |||
| 71803d7c76 | |||
| 93945d603e | |||
| acd1641d16 | |||
| 63711bf1dc | |||
| 37f79ee5d1 | |||
| 5faa4f9520 | |||
| 6c4f293719 |
@ -1,15 +0,0 @@
|
||||
**/node_modules
|
||||
**/.pnpm-store
|
||||
**/dist
|
||||
**/.next
|
||||
**/.turbo
|
||||
**/.cache
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/.mypy_cache
|
||||
**/.ruff_cache
|
||||
.git
|
||||
.github
|
||||
*.md
|
||||
!web/README.md
|
||||
!api/README.md
|
||||
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -5,7 +5,3 @@
|
||||
# them.
|
||||
|
||||
*.sh text eol=lf
|
||||
|
||||
# Codegen output must stay byte-identical across platforms so
|
||||
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
|
||||
*.generated.ts text eol=lf
|
||||
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,10 +18,6 @@
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# CLI
|
||||
/cli/ @langgenius/maintainers
|
||||
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
|
||||
111
.github/dependabot.yml
vendored
111
.github/dependabot.yml
vendored
@ -110,114 +110,3 @@ updates:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 5
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
88
.github/workflows/cli-release.yml
vendored
88
.github/workflows/cli-release.yml
vendored
@ -1,88 +0,0 @@
|
||||
name: CLI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- 'difyctl-v*'
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: build standalone binaries (all targets)
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Install cross-arch native prebuilds
|
||||
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||
# so `bun build --compile` can embed the right .node into each target.
|
||||
working-directory: ./
|
||||
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||
|
||||
- name: Compile standalone binaries (all targets)
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
run: |
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build:bin
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
run: scripts/release-write-checksums.sh
|
||||
|
||||
- name: Publish GitHub Release
|
||||
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||
with:
|
||||
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||
generate_release_notes: true
|
||||
fail_on_unmatched_files: true
|
||||
files: |
|
||||
cli/dist/bin/difyctl-v*
|
||||
60
.github/workflows/cli-smoke.yml
vendored
60
.github/workflows/cli-smoke.yml
vendored
@ -1,60 +0,0 @@
|
||||
name: CLI Smoke (live dify)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_version:
|
||||
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||
type: string
|
||||
required: true
|
||||
cli_ref:
|
||||
description: "Git ref to build the cli from (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
smoke:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Bring up dify
|
||||
env:
|
||||
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||
run: |
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||
docker compose up -d api worker web db redis
|
||||
for i in $(seq 1 60); do
|
||||
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||
echo "dify api ready after ${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run smoke against live dify
|
||||
working-directory: ./cli
|
||||
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||
|
||||
- name: Dump dify logs on failure
|
||||
if: failure()
|
||||
run: |
|
||||
cd docker
|
||||
docker compose logs api worker web --tail=200
|
||||
46
.github/workflows/cli-tests.yml
vendored
46
.github/workflows/cli-tests.yml
vendored
@ -1,46 +0,0 @@
|
||||
name: CLI Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,7 +42,6 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
@ -63,18 +62,6 @@ jobs:
|
||||
- 'docker/generate_docker_compose'
|
||||
- 'docker/ssrf_proxy/**'
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
cli:
|
||||
- 'cli/**'
|
||||
- 'packages/tsconfig/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- 'eslint.config.mjs'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/cli-tests.yml'
|
||||
- '.github/workflows/cli-docker-build.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
@ -197,66 +184,6 @@ jobs:
|
||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
cli-tests-run:
|
||||
name: Run CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||
uses: ./.github/workflows/cli-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests-skip:
|
||||
name: Skip CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped CLI tests
|
||||
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
if: ${{ always() }}
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
- cli-tests-run
|
||||
- cli-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize CLI Tests status
|
||||
env:
|
||||
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||
run: |
|
||||
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests ran successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests were skipped because no CLI-related files changed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
web-tests-run:
|
||||
name: Run Web Tests
|
||||
needs:
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@ -115,12 +115,6 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||
!/cli/src/env/
|
||||
!/cli/src/commands/env/
|
||||
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||
!/cli/scripts/lib/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
@ -253,9 +247,8 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
*.local.toml
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.context/
|
||||
.context/*
|
||||
.eslintcache
|
||||
|
||||
@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
|
||||
@ -17,7 +17,7 @@ FROM base AS packages
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
git g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
|
||||
@ -159,7 +159,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -204,7 +203,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -30,7 +30,7 @@ from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
@ -42,7 +42,7 @@ from clients.agent_backend.request_builder import (
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
|
||||
@ -4,9 +4,7 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
|
||||
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
|
||||
protocol has a single owner. API-only context such as Agent Soul vs workflow job
|
||||
prompt is preserved in layer names and metadata until the dedicated product
|
||||
schemas land in later phases. Dify-owned execution identifiers are emitted as an
|
||||
explicit ``dify.execution_context`` layer so the run request stays fully
|
||||
composition-driven.
|
||||
schemas land in later phases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -17,19 +15,18 @@ from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLayerConfig,
|
||||
DifyPluginLLMLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
ExecutionContext,
|
||||
LayerExitSignals,
|
||||
RunComposition,
|
||||
RunLayerSpec,
|
||||
@ -40,15 +37,17 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
model_provider: str
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
@ -56,14 +55,10 @@ class AgentBackendModelConfig(BaseModel):
|
||||
|
||||
|
||||
class AgentBackendOutputConfig(BaseModel):
|
||||
"""API-side structured output declaration for the conventional output layer.
|
||||
|
||||
The structured-output tool name is fixed to ``final_output`` inside
|
||||
``dify_agent.layers.output`` so callers only control the JSON Schema plus
|
||||
optional description/strictness metadata.
|
||||
"""
|
||||
"""API-side structured output declaration for the conventional output layer."""
|
||||
|
||||
json_schema: dict[str, JsonValue]
|
||||
name: str = "final_result"
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
@ -74,7 +69,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
execution_context: ExecutionContext
|
||||
workflow_node_job_prompt: str
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
@ -126,18 +121,21 @@ class AgentBackendRunRequestBuilder:
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
config=DifyPluginLayerConfig(
|
||||
tenant_id=run_input.model.tenant_id,
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
user_id=run_input.model.user_id,
|
||||
),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
@ -155,6 +153,7 @@ class AgentBackendRunRequestBuilder:
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
name=run_input.output.name,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
@ -163,6 +162,7 @@ class AgentBackendRunRequestBuilder:
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
execution_context=run_input.execution_context,
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
|
||||
@ -11,7 +11,6 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -25,7 +23,7 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
|
||||
EDITION: str = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
)
|
||||
|
||||
@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
|
||||
description="TTL in seconds for caching tenant plugin model providers in Redis",
|
||||
default=60 * 60 * 24,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||
default=50 * 1024 * 1024,
|
||||
@ -525,44 +520,6 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
OPENAPI_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||
"programmatic clients. Set to true to activate; disabled by default."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -938,17 +895,6 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1235,14 +1181,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
|
||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, override
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
|
||||
@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, override, runtime_checkable
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
@ -6,7 +6,7 @@ import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final, override
|
||||
from typing import Any, final
|
||||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
|
||||
"""
|
||||
self._flask_app = flask_app
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value from Flask app config."""
|
||||
return self._flask_app.config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name."""
|
||||
return self._flask_app.extensions.get(name)
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask app context."""
|
||||
with self._flask_app.app_context():
|
||||
|
||||
@ -1,10 +1,40 @@
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
from pydantic import BaseModel, Field, JsonValue
|
||||
|
||||
HUMAN_INPUT_FORM_INPUT_EXAMPLE = {
|
||||
"decision": "approve",
|
||||
"attachment": {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "4e0d1b87-52f2-49f6-b8c6-95cd9c954b3e",
|
||||
"type": "document",
|
||||
},
|
||||
"attachments": [
|
||||
{
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": "1a77f0df-c0e6-461c-987c-e72526f341ee",
|
||||
"type": "document",
|
||||
},
|
||||
{
|
||||
"transfer_method": "remote_url",
|
||||
"url": "https://example.com/report.pdf",
|
||||
"type": "document",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
inputs: dict[str, JsonValue] = Field(
|
||||
description=(
|
||||
"Submitted human input values keyed by output variable name. "
|
||||
"Use a string for paragraph or select input values, a file mapping for file inputs, "
|
||||
"and a list of file mappings for file-list inputs. Local file mappings use "
|
||||
"`transfer_method=local_file` with `upload_file_id`; remote file mappings use "
|
||||
"`transfer_method=remote_url` with `url` or `remote_url`."
|
||||
),
|
||||
examples=[HUMAN_INPUT_FORM_INPUT_EXAMPLE],
|
||||
)
|
||||
action: str
|
||||
|
||||
|
||||
|
||||
@ -68,7 +68,6 @@ from .app import (
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_node_output_inspector,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
@ -219,7 +218,6 @@ __all__ = [
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_node_output_inspector",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@ -82,7 +80,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
|
||||
|
||||
@ -91,7 +89,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def patch(self, agent_id: UUID):
|
||||
def patch(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
return _agent_roster_service().update_roster_agent(
|
||||
@ -102,7 +100,7 @@ class AgentRosterDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, agent_id: UUID):
|
||||
def delete(self, agent_id):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
|
||||
return "", 204
|
||||
@ -113,7 +111,7 @@ class AgentRosterVersionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID):
|
||||
def get(self, agent_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
|
||||
|
||||
@ -123,7 +121,7 @@ class AgentRosterVersionDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, agent_id: UUID, version_id: UUID):
|
||||
def get(self, agent_id, version_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return _agent_roster_service().get_agent_version_detail(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import flask_restx
|
||||
from flask_restx import Resource
|
||||
@ -9,25 +8,18 @@ from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
@ -47,7 +39,7 @@ class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -71,11 +63,10 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
|
||||
|
||||
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
@ -83,14 +74,13 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@edit_permission_required
|
||||
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
|
||||
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
|
||||
|
||||
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
@ -117,7 +107,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -127,20 +117,9 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(
|
||||
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
|
||||
) -> tuple[str, int]:
|
||||
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
def _delete_api_key(
|
||||
self,
|
||||
resource_id: str,
|
||||
api_key_id: str,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> None:
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
@ -167,6 +146,8 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@ -174,21 +155,18 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for an app"""
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for an app"""
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -202,14 +180,9 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -222,21 +195,18 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
def get(self, resource_id): # type: ignore
|
||||
"""Get all API keys for a dataset"""
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
def post(self, resource_id): # type: ignore
|
||||
"""Create a new API key for a dataset"""
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
@ -250,14 +220,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@ -159,15 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id: UUID, annotation_setting_id: UUID):
|
||||
annotation_setting_id_str = str(annotation_setting_id)
|
||||
def post(self, app_id: UUID, annotation_setting_id):
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
|
||||
result = AppAnnotationService.update_app_annotation_setting(
|
||||
str(app_id), annotation_setting_id_str, setting_args
|
||||
)
|
||||
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
|
||||
return result, 200
|
||||
|
||||
|
||||
@ -183,9 +181,9 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id: UUID, job_id: UUID, action: str):
|
||||
job_id_str = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
|
||||
def get(self, app_id: UUID, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
@ -193,10 +191,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
|
||||
@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model, with_session
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -26,6 +26,7 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -851,11 +852,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_session
|
||||
@get_app_model
|
||||
def get(self, session: Session, app_model: App):
|
||||
def get(self, app_model):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, import_id: str):
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
|
||||
@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id: str):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, request
|
||||
@ -134,7 +133,7 @@ class CompletionConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
.distinct()
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@ -165,10 +164,10 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@ -182,12 +181,12 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id_str, current_user)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -272,7 +271,7 @@ class ChatConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
.distinct()
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
@ -318,10 +317,10 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
def get(self, app_model, conversation_id):
|
||||
conversation_id = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
_get_conversation(app_model, conversation_id), from_attributes=True
|
||||
).model_dump(mode="json")
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@ -335,12 +334,12 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
ConversationService.delete(app_model, conversation_id_str, current_user)
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -163,7 +162,7 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, server_id: UUID):
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -337,13 +336,13 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id: UUID):
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, message_id=message_id_str, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
@ -417,11 +416,11 @@ class MessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
select(Message).where(Message.id == message_id_str, Message.app_id == app_model.id).limit(1)
|
||||
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -346,15 +345,14 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -365,7 +363,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -392,11 +390,10 @@ class VariableApi(Resource):
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -437,15 +434,14 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -461,7 +457,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -472,11 +468,10 @@ class VariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
|
||||
@ -1,415 +0,0 @@
|
||||
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
|
||||
|
||||
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
|
||||
with a producer-organized view of each node's declared outputs and their
|
||||
per-run status. This module exposes two parallel sets of three read-only
|
||||
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
|
||||
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
|
||||
schedule / plugin triggers). Both sets share the same service code, the same
|
||||
response shapes, and the same error codes; the URL is the *only* difference,
|
||||
so the frontend can pick the right prefix based on which run-detail page the
|
||||
user is on.
|
||||
|
||||
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
|
||||
``published_run_inspector_not_implemented`` 404 code is therefore no longer
|
||||
produced.
|
||||
|
||||
URLs follow the design doc and reuse the existing
|
||||
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
|
||||
:mod:`controllers.console.app.workflow_draft_variable`. The
|
||||
``published`` prefix mirrors it shape-for-shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from services.workflow import inspector_events
|
||||
from services.workflow.node_output_inspector_service import (
|
||||
NodeOutputInspectorError,
|
||||
NodeOutputInspectorService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
|
||||
# intervening proxies (nginx, ingress) don't reap the idle connection.
|
||||
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
|
||||
_HEARTBEAT_EVERY_TICKS = 15
|
||||
# Hard ceiling on a single stream — if we never see a terminal workflow
|
||||
# event (engine crashed, redis dropped the message), force-close after this
|
||||
# many ticks (= seconds).
|
||||
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
|
||||
|
||||
|
||||
def _service() -> NodeOutputInspectorService:
|
||||
"""One-line factory so tests can monkeypatch a stub if needed."""
|
||||
return NodeOutputInspectorService()
|
||||
|
||||
|
||||
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
|
||||
"""Resource-body shared by draft + published snapshot endpoints.
|
||||
|
||||
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
|
||||
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
|
||||
behaviour lives here, where unit tests can hit it without spinning up
|
||||
Flask request context.
|
||||
"""
|
||||
try:
|
||||
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return snapshot.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
|
||||
"""Resource-body shared by draft + published node-detail endpoints."""
|
||||
try:
|
||||
view = _service().node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return view.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
|
||||
"""Resource-body shared by draft + published output-preview endpoints."""
|
||||
try:
|
||||
preview = _service().output_preview(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
output_name=output_name,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return preview.model_dump(mode="json")
|
||||
|
||||
|
||||
class _InspectorNotFound(BaseHTTPException):
|
||||
"""404 that preserves the inspector's specific error code.
|
||||
|
||||
Without this the response body collapses to a generic ``not_found`` code
|
||||
and clients lose the ability to distinguish, e.g.,
|
||||
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
|
||||
"""
|
||||
|
||||
code = 404
|
||||
|
||||
def __init__(self, error: NodeOutputInspectorError) -> None:
|
||||
self.error_code = error.code
|
||||
super().__init__(description=str(error))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowDraftRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot organized by producer node."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowDraftRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output (with signed URL for file refs)."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# SSE event stream — shared generator used by draft + published variants
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
|
||||
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
|
||||
|
||||
``data`` is JSON-serialized when given as a dict; raw strings are
|
||||
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
|
||||
"""
|
||||
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
|
||||
"""Yield SSE-framed strings for one workflow run.
|
||||
|
||||
The stream begins with a full ``snapshot`` event so the client has a
|
||||
starting state without needing a separate REST GET. Then for every
|
||||
``node_changed`` message from the pub/sub channel we re-read that node
|
||||
from DB and push a fresh ``node_changed`` event. When the workflow run
|
||||
reaches a terminal state we push one final ``workflow_run_completed``
|
||||
event and close the stream.
|
||||
|
||||
Failures inside the loop are caught and surfaced as ``error`` events so
|
||||
the frontend can show a banner rather than seeing the connection drop
|
||||
silently. The Inspector never raises across the SSE boundary.
|
||||
"""
|
||||
service = _service()
|
||||
run_id_str = str(run_id)
|
||||
|
||||
# Initial snapshot — also flushes a 404 back at the client right away
|
||||
# if the run is gone (raised before yielding any bytes, so Flask turns it
|
||||
# into the normal HTTP 404 path).
|
||||
try:
|
||||
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
|
||||
event_id = 0
|
||||
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
|
||||
|
||||
# If the run already finished by the time the client connected, emit
|
||||
# the terminal envelope synchronously and close — no point subscribing.
|
||||
# The enum value for partial success is the hyphenated ``partial-succeeded``
|
||||
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
|
||||
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Live subscription
|
||||
ticks_since_heartbeat = 0
|
||||
total_ticks = 0
|
||||
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
|
||||
total_ticks += 1
|
||||
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
|
||||
logger.warning(
|
||||
"Inspector SSE: forcing close after %ds without terminal event for run %s",
|
||||
_STREAM_HARD_TIMEOUT_TICKS,
|
||||
run_id_str,
|
||||
)
|
||||
return
|
||||
|
||||
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
|
||||
# ``node_changed`` message with both fields ``None`` on every redis
|
||||
# timeout. Real ``workflow_completed`` messages keep their kind even
|
||||
# when status couldn't be resolved (publisher race), so checking kind
|
||||
# first makes the heartbeat branch safe.
|
||||
if message.kind == "node_changed" and message.node_id is None and message.status is None:
|
||||
ticks_since_heartbeat += 1
|
||||
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
|
||||
yield ":keepalive\n\n"
|
||||
ticks_since_heartbeat = 0
|
||||
continue
|
||||
ticks_since_heartbeat = 0
|
||||
|
||||
if message.kind == "workflow_completed":
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# node_changed: recompute the node slice from DB
|
||||
if not message.node_id:
|
||||
continue
|
||||
try:
|
||||
node_view = service.node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=run_id_str,
|
||||
node_id=message.node_id,
|
||||
)
|
||||
except NodeOutputInspectorError:
|
||||
# Node may not appear in the graph yet (race with persistence); skip.
|
||||
continue
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Inspector SSE: node_detail failed for run %s node %s",
|
||||
run_id_str,
|
||||
message.node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"error",
|
||||
{"node_id": message.node_id, "message": "failed to refresh node detail"},
|
||||
event_id,
|
||||
)
|
||||
continue
|
||||
|
||||
event_id += 1
|
||||
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowDraftRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a draft run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_draft_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Published-run endpoints — symmetric to the draft trio above
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowPublishedRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot for a *published* workflow run.
|
||||
|
||||
Same response shape as the ``/draft/`` variant — frontend can multiplex
|
||||
based on which page (Composer test-run vs. Run History) is mounted.
|
||||
"""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status (published run)."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
|
||||
"/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output of a published run."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output of a published run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a published run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_published_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -189,7 +188,7 @@ class WorkflowRunExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id: str):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
run_id_str = str(run_id)
|
||||
@ -368,14 +367,14 @@ class WorkflowRunDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id_str)
|
||||
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
@ -397,17 +396,17 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = workflow_run_service.get_workflow_run_node_executions(
|
||||
app_model=app_model,
|
||||
run_id=run_id_str,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
||||
@ -1,38 +1,16 @@
|
||||
"""Controller decorators for console app resources.
|
||||
|
||||
`with_session` opens one SQLAlchemy session for a request handler and injects it
|
||||
as the first argument after `self`. Handlers use a transaction by default so
|
||||
migrated write paths keep commit/rollback handling; pure read handlers may opt
|
||||
out with `write=False`. App-loading decorators prefer that injected session when
|
||||
present, while still supporting existing handlers that have not been migrated
|
||||
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, cast, overload
|
||||
from typing import overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(session: Session, app_id: str) -> App | None:
|
||||
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
|
||||
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
@ -45,63 +23,6 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R],
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
|
||||
|
||||
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R] | None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> (
|
||||
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
|
||||
):
|
||||
"""Inject a request-scoped session, using a transaction only for write handlers."""
|
||||
|
||||
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if write:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
|
||||
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
candidate = args[1]
|
||||
if isinstance(candidate, Session):
|
||||
return candidate
|
||||
|
||||
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
|
||||
return cast(Session, candidate)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
@ -123,13 +44,6 @@ def get_app_model[**P, R](
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Inject the App model for handlers that receive an `app_id` path parameter.
|
||||
|
||||
New handlers may compose `@with_session` above this decorator so the app row
|
||||
is loaded through the same request-scoped session used by the controller.
|
||||
Existing handlers continue to work through `db.session` until migrated.
|
||||
"""
|
||||
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@ -141,11 +55,7 @@ def get_app_model[**P, R](
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
session = _get_injected_session(args)
|
||||
if session is None:
|
||||
app_model = _load_app_model_from_scoped_session(app_id)
|
||||
else:
|
||||
app_model = _load_app_model(session, app_id)
|
||||
app_model = _load_app_model(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
@ -1,16 +1,14 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
@ -42,8 +40,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
@ -69,9 +67,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
@ -89,9 +87,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
@ -159,15 +158,16 @@ class OAuthDataSourceSync(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider: str, binding_id: UUID):
|
||||
binding_id_str = str(binding_id)
|
||||
def get(self, provider, binding_id):
|
||||
provider = str(provider)
|
||||
binding_id = str(binding_id)
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id_str)
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
|
||||
@ -8,9 +8,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
@ -133,10 +133,12 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
|
||||
user_account_id = current_user.id
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
@ -48,6 +47,7 @@ class NotionEstimatePayload(BaseModel):
|
||||
class DataSourceNotionListQuery(BaseModel):
|
||||
dataset_id: str | None = Field(default=None, description="Dataset ID")
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
|
||||
|
||||
|
||||
class DataSourceNotionPreviewQuery(BaseModel):
|
||||
@ -204,6 +204,9 @@ class DataSourceNotionListApi(Resource):
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters = query.datasource_parameters or {}
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -251,7 +254,7 @@ class DataSourceNotionListApi(Resource):
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
@ -290,7 +293,7 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str):
|
||||
def get(self, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
|
||||
@ -303,11 +306,11 @@ class DataSourceNotionApi(Resource):
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
page_id_str = str(page_id)
|
||||
page_id = str(page_id)
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id="",
|
||||
notion_obj_id=page_id_str,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_tenant_id,
|
||||
@ -364,7 +367,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -382,7 +385,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -1,17 +1,15 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -32,10 +30,26 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
dataset_detail_fields,
|
||||
dataset_fields,
|
||||
dataset_query_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
file_info_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
@ -47,6 +61,58 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
tag_model = get_or_create_model("Tag", tag_fields)
|
||||
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
|
||||
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
|
||||
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
|
||||
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
|
||||
|
||||
content_fields_copy = content_fields.copy()
|
||||
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
|
||||
content_model = get_or_create_model("DatasetContent", content_fields_copy)
|
||||
|
||||
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
|
||||
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
|
||||
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
|
||||
|
||||
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
|
||||
related_app_list_copy = related_app_list.copy()
|
||||
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
|
||||
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
@ -142,165 +208,9 @@ class ConsoleDatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetListItemResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str]
|
||||
|
||||
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetListItemResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
class DatasetQueryFileInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class DatasetQueryContentResponse(ResponseModel):
|
||||
content_type: str
|
||||
content: str
|
||||
file_info: DatasetQueryFileInfoResponse | None = None
|
||||
|
||||
|
||||
class DatasetQueryDetailResponse(ResponseModel):
|
||||
id: str
|
||||
queries: list[DatasetQueryContentResponse]
|
||||
source: str
|
||||
source_app_id: str | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: int
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DatasetQueryListResponse(ResponseModel):
|
||||
data: list[DatasetQueryDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class RelatedAppResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
mode: str = Field(validation_alias="mode_compatible_with_agent")
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
icon_url: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_icon_url(self) -> "RelatedAppResponse":
|
||||
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
|
||||
return self
|
||||
|
||||
|
||||
class RelatedAppListResponse(ResponseModel):
|
||||
data: list[RelatedAppResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class DocumentStatusResponse(ResponseModel):
|
||||
id: str
|
||||
indexing_status: str
|
||||
processing_started_at: int | None
|
||||
parsing_completed_at: int | None
|
||||
cleaning_completed_at: int | None
|
||||
splitting_completed_at: int | None
|
||||
completed_at: int | None
|
||||
paused_at: int | None
|
||||
error: str | None
|
||||
stopped_at: int | None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
|
||||
@field_validator(
|
||||
"processing_started_at",
|
||||
"parsing_completed_at",
|
||||
"cleaning_completed_at",
|
||||
"splitting_completed_at",
|
||||
"completed_at",
|
||||
"paused_at",
|
||||
"stopped_at",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentStatusListResponse(ResponseModel):
|
||||
data: list[DocumentStatusResponse]
|
||||
|
||||
|
||||
class ErrorDocsResponse(DocumentStatusListResponse):
|
||||
total: int
|
||||
|
||||
|
||||
class IndexingEstimatePreviewItemResponse(ResponseModel):
|
||||
content: str
|
||||
child_chunks: list[str] | None = None
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class IndexingEstimateResponse(ResponseModel):
|
||||
total_segments: int
|
||||
preview: list[IndexingEstimatePreviewItemResponse]
|
||||
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
|
||||
|
||||
|
||||
class RetrievalSettingResponse(ResponseModel):
|
||||
retrieval_method: list[str]
|
||||
|
||||
|
||||
class PartialMemberListResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
class AutoDisableLogsResponse(ResponseModel):
|
||||
document_ids: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetQueryListResponse,
|
||||
IndexingEstimateResponse,
|
||||
RelatedAppListResponse,
|
||||
DocumentStatusListResponse,
|
||||
ErrorDocsResponse,
|
||||
RetrievalSettingResponse,
|
||||
PartialMemberListResponse,
|
||||
AutoDisableLogsResponse,
|
||||
)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
@ -383,8 +293,17 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
class DatasetListApi(Resource):
|
||||
@console_ns.doc("get_datasets")
|
||||
@console_ns.doc(description="Get list of datasets")
|
||||
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
|
||||
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"ids": "Filter by dataset IDs (list)",
|
||||
"keyword": "Search keyword",
|
||||
"tag_ids": "Filter by tag IDs (list)",
|
||||
"include_all": "Include all datasets (default: false)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Datasets retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -423,7 +342,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
|
||||
partial_members_map: dict[str, list[str]] = {}
|
||||
if dataset_ids:
|
||||
@ -460,12 +379,12 @@ class DatasetListApi(Resource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -494,7 +413,7 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -502,17 +421,13 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("get_dataset")
|
||||
@console_ns.doc(description="Get dataset details")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -522,7 +437,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@ -555,18 +470,14 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID):
|
||||
def patch(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -595,7 +506,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
@ -614,7 +525,7 @@ class DatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Dataset deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
def delete(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@ -644,7 +555,7 @@ class DatasetUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
|
||||
@ -656,15 +567,11 @@ class DatasetQueryApi(Resource):
|
||||
@console_ns.doc("get_dataset_queries")
|
||||
@console_ns.doc(description="Get dataset query history")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Query history retrieved successfully",
|
||||
console_ns.models[DatasetQueryListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -682,24 +589,20 @@ class DatasetQueryApi(Resource):
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
|
||||
response = {
|
||||
"data": dataset_queries,
|
||||
"data": marshal(dataset_queries, dataset_query_detail_model),
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return dump_response(DatasetQueryListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/indexing-estimate")
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
@console_ns.doc("estimate_dataset_indexing")
|
||||
@console_ns.doc(description="Estimate dataset indexing cost")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing estimate calculated successfully",
|
||||
console_ns.models[IndexingEstimateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -796,15 +699,12 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@console_ns.doc("get_dataset_related_apps")
|
||||
@console_ns.doc(description="Get applications related to dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Related apps retrieved successfully",
|
||||
console_ns.models[RelatedAppListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
@marshal_with(related_app_list_model)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -824,7 +724,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
if app_model:
|
||||
related_apps.append(app_model)
|
||||
|
||||
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
|
||||
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
|
||||
@ -832,19 +732,15 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@console_ns.doc("get_dataset_indexing_status")
|
||||
@console_ns.doc(description="Get dataset indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing status retrieved successfully",
|
||||
console_ns.models[DocumentStatusListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
documents = db.session.scalars(
|
||||
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
|
||||
).all()
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
@ -882,8 +778,9 @@ class DatasetIndexingStatusApi(Resource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys")
|
||||
@ -952,15 +849,15 @@ class DatasetApiDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, api_key_id: UUID):
|
||||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id_str = str(api_key_id)
|
||||
api_key_id = str(api_key_id)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id_str,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -985,7 +882,7 @@ class DatasetEnableApiApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, dataset_id: UUID, status: str):
|
||||
def post(self, dataset_id, status):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
|
||||
@ -1010,18 +907,13 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting")
|
||||
@console_ns.doc(description="Get dataset retrieval settings")
|
||||
@console_ns.response(
|
||||
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
|
||||
)
|
||||
@console_ns.response(200, "Retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
|
||||
)
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
@ -1029,19 +921,12 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting_mock")
|
||||
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@console_ns.doc(params={"vector_type": "Vector store type"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Mock retrieval settings retrieved successfully",
|
||||
console_ns.models[RetrievalSettingResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type: str):
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
|
||||
)
|
||||
def get(self, vector_type):
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
@ -1049,19 +934,19 @@ class DatasetErrorDocs(Resource):
|
||||
@console_ns.doc("get_dataset_error_docs")
|
||||
@console_ns.doc(description="Get dataset error documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
|
||||
@console_ns.response(200, "Error documents retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
|
||||
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
@ -1069,17 +954,13 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@console_ns.doc("get_dataset_permission_users")
|
||||
@console_ns.doc(description="Get dataset permission user list")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Permission users retrieved successfully",
|
||||
console_ns.models[PartialMemberListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Permission users retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -1092,7 +973,9 @@ class DatasetPermissionUserListApi(Resource):
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
|
||||
return {
|
||||
"data": partial_members_list,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
|
||||
@ -1100,18 +983,14 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
@console_ns.doc("get_dataset_auto_disable_logs")
|
||||
@console_ns.doc(description="Get dataset auto disable logs")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Auto disable logs retrieved successfully",
|
||||
console_ns.models[AutoDisableLogsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(200, "Auto disable logs retrieved successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
|
||||
@ -5,7 +5,6 @@ from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
@ -316,9 +315,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
page = param.page
|
||||
@ -343,7 +342,7 @@ class DatasetDocumentListApi(Resource):
|
||||
)
|
||||
except (ArgumentTypeError, ValueError, Exception):
|
||||
fetch = False
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -352,7 +351,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
|
||||
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
@ -373,7 +372,7 @@ class DatasetDocumentListApi(Resource):
|
||||
sa.select(
|
||||
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
|
||||
)
|
||||
.where(DocumentSegment.dataset_id == dataset_id_str)
|
||||
.where(DocumentSegment.dataset_id == str(dataset_id))
|
||||
.group_by(DocumentSegment.document_id)
|
||||
.subquery()
|
||||
)
|
||||
@ -445,11 +444,11 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -473,7 +472,7 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -491,9 +490,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Documents deleted successfully")
|
||||
def delete(self, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
def delete(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
@ -583,11 +582,11 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -625,7 +624,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
"English",
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
return estimate_response.model_dump(), 200
|
||||
except LLMBadRequestError:
|
||||
@ -648,10 +647,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
def get(self, dataset_id, batch):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -725,7 +725,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
"English",
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
return response.model_dump(), 200
|
||||
except LLMBadRequestError:
|
||||
@ -745,9 +745,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
def get(self, dataset_id, batch):
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -799,16 +800,16 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -817,7 +818,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -860,10 +861,10 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -872,7 +873,7 @@ class DocumentApi(DocumentResource):
|
||||
if metadata == "only":
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -906,7 +907,7 @@ class DocumentApi(DocumentResource):
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -949,16 +950,16 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -979,7 +980,7 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
@ -996,16 +997,16 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: str):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=dataset_id_str,
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids,
|
||||
tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
@ -1043,11 +1044,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1091,11 +1092,11 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
def put(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1140,10 +1141,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1178,16 +1179,16 @@ class DocumentPauseApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document paused successfully")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID):
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""pause document."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
@ -1213,14 +1214,14 @@ class DocumentRecoverApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document resumed successfully")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID):
|
||||
def patch(self, dataset_id, document_id):
|
||||
"""recover document."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
if document is None:
|
||||
@ -1246,11 +1247,11 @@ class DocumentRetryApi(DocumentResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
|
||||
@console_ns.response(204, "Documents retry started successfully")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
retry_documents = []
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -1276,7 +1277,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
logger.exception("Failed to retry document, document id: %s", document_id)
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id_str, retry_documents)
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1288,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1300,7 +1301,7 @@ class DocumentRenameApi(DocumentResource):
|
||||
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name)
|
||||
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
@ -1313,15 +1314,15 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.tenant_id != current_tenant_id:
|
||||
@ -1332,7 +1333,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
# sync document
|
||||
DocumentService.sync_website_document(dataset_id_str, document)
|
||||
DocumentService.sync_website_document(dataset_id, document)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1342,19 +1343,19 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = db.session.scalar(
|
||||
select(DocumentPipelineExecutionLog)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id_str)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
@ -1391,7 +1392,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1400,10 +1401,10 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1437,7 +1438,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list)
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
@ -1451,7 +1452,7 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
if documents_to_update:
|
||||
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
|
||||
DocumentService.update_documents_need_summary(
|
||||
dataset_id=dataset_id_str,
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids_to_update,
|
||||
need_summary=True,
|
||||
)
|
||||
@ -1464,11 +1465,11 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task.delay(dataset_id_str, document.id)
|
||||
generate_summary_index_task.delay(dataset_id, document.id)
|
||||
logger.info(
|
||||
"Dispatched summary generation task for document %s in dataset %s",
|
||||
document.id,
|
||||
dataset_id_str,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
@ -1484,7 +1485,7 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1498,11 +1499,11 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -1516,8 +1517,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
result = SummaryIndexService.get_document_summary_status_detail(
|
||||
document_id=document_id_str,
|
||||
dataset_id=dataset_id_str,
|
||||
document_id=document_id,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
return result, 200
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
@ -115,12 +113,12 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
@ -129,7 +127,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
@ -150,7 +148,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
query = (
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
@ -203,9 +201,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if segment_ids:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
summary_records = SummaryIndexService.get_segments_summaries(
|
||||
segment_ids=segment_ids, dataset_id=dataset_id_str
|
||||
)
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
# Only include enabled summaries (already filtered by service)
|
||||
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
|
||||
|
||||
@ -230,19 +226,19 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_ids = request.args.getlist("segment_id")
|
||||
@ -266,15 +262,15 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check user's model setting
|
||||
@ -325,17 +321,17 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -365,7 +361,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -376,19 +372,19 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@ -408,10 +404,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -432,33 +428,33 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -487,17 +483,17 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
@ -521,8 +517,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
batch_create_segment_to_index_task.delay(
|
||||
str(job_id),
|
||||
upload_file_id,
|
||||
dataset_id_str,
|
||||
document_id_str,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
@ -534,7 +530,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id=None, dataset_id: UUID | None = None, document_id: UUID | None = None):
|
||||
def get(self, job_id=None, dataset_id=None, document_id=None):
|
||||
if job_id is None:
|
||||
raise NotFound("The job does not exist.")
|
||||
job_id = str(job_id)
|
||||
@ -555,24 +551,24 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -610,26 +606,26 @@ class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -646,9 +642,7 @@ class ChildChunkAddApi(Resource):
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
"total": child_chunks.total,
|
||||
@ -662,26 +656,26 @@ class ChildChunkAddApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
@ -711,39 +705,39 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
@ -768,39 +762,39 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id_str = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id_str = str(segment_id)
|
||||
segment_id = str(segment_id)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id_str),
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id_str,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
@ -10,12 +8,7 @@ from controllers.common.fields import UsageCountResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
@ -131,9 +124,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_tenant_id
|
||||
@account_initialization_required
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
@ -182,11 +175,11 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
@ -197,9 +190,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -207,7 +200,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id_str,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=payload.model_dump(),
|
||||
)
|
||||
|
||||
@ -217,14 +210,14 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id_str)
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -237,12 +230,12 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
|
||||
|
||||
@ -293,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
@ -119,7 +118,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -43,7 +42,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -63,7 +62,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@console_ns.response(
|
||||
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, dataset_id: UUID):
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -80,7 +79,7 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, metadata_id: UUID):
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
|
||||
name = payload.name
|
||||
@ -100,7 +99,7 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Metadata deleted successfully")
|
||||
def delete(self, dataset_id: UUID, metadata_id: UUID):
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@ -137,7 +136,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -165,7 +164,7 @@ class DocumentMetadataEditApi(Resource):
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
@ -54,13 +54,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
yaml_content=payload.yaml_content,
|
||||
)
|
||||
try:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
session.commit()
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_tenant_id,
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
@ -169,22 +168,21 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -212,12 +210,11 @@ class RagPipelineVariableApi(Resource):
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
@ -253,16 +250,15 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
@ -271,7 +267,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -282,12 +278,11 @@ class RagPipelineVariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@ -67,12 +67,10 @@ class RagPipelineImportApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Use a plain Session so that caught exceptions inside the service
|
||||
# (which return FAILED status instead of re-raising) do not leave the
|
||||
# transaction in a closed state that a .begin() context manager cannot
|
||||
# handle. See app_import.py for the canonical pattern.
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
@ -82,10 +80,6 @@ class RagPipelineImportApi(Resource):
|
||||
pipeline_id=payload.pipeline_id,
|
||||
dataset_name=payload.name,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
@ -105,17 +99,15 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_model)
|
||||
def post(self, import_id: str):
|
||||
def post(self, import_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@ -132,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
@ -150,7 +142,7 @@ class RagPipelineExportApi(Resource):
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
@ -876,14 +875,14 @@ class RagPipelineWorkflowRunDetailApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
def get(self, pipeline: Pipeline, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id_str)
|
||||
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id)
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
@ -901,17 +900,17 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
def get(self, pipeline: Pipeline, run_id: str):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
run_id_str = str(run_id)
|
||||
run_id = str(run_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
user = cast("Account | EndUser", current_user)
|
||||
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
|
||||
pipeline=pipeline,
|
||||
run_id=run_id_str,
|
||||
run_id=run_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
@ -961,15 +960,15 @@ class RagPipelineTransformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset_id = str(dataset_id)
|
||||
rag_pipeline_transform_service = RagPipelineTransformService()
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
|
||||
result = rag_pipeline_transform_service.transform_dataset(dataset_id)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -41,10 +40,8 @@ register_schema_model(console_ns, TextToAudioPayload)
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
@ -84,10 +81,8 @@ class ChatAudioApi(InstalledAppResource):
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -83,10 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -135,10 +133,8 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -161,10 +157,8 @@ class CompletionStopApi(InstalledAppResource):
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -215,10 +209,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
def post(self, installed_app, task_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
@ -8,7 +7,6 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,7 +19,7 @@ from fields.conversation_fields import (
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -45,10 +43,8 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -95,10 +91,8 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def delete(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -120,10 +114,8 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def post(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -153,10 +145,8 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -179,10 +169,8 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def patch(self, installed_app, c_id):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
def delete(self, installed_app):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp):
|
||||
def patch(self, installed_app):
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@ -10,7 +9,6 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -22,16 +20,15 @@ from controllers.console.explore.error import (
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -61,11 +58,9 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -100,20 +95,18 @@ class MessageListApi(InstalledAppResource):
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def post(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
@ -130,15 +123,13 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
|
||||
|
||||
@ -148,7 +139,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
response = AppGenerateService.generate_more_like_this(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
streaming=streaming,
|
||||
)
|
||||
@ -178,20 +169,18 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def get(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=current_user, message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE
|
||||
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message not found")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
@ -7,14 +5,11 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models import Account
|
||||
from models.model import InstalledApp
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -25,11 +20,9 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -51,11 +44,9 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -74,17 +65,15 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def delete(self, installed_app, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id_str)
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -13,7 +13,6 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -26,7 +25,7 @@ from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -42,11 +41,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -9,14 +8,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class CodeBasedExtensionQuery(BaseModel):
|
||||
@ -116,11 +115,11 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@ -130,9 +129,9 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -153,12 +152,12 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, id: UUID):
|
||||
def get(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@ -169,9 +168,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, id: UUID):
|
||||
def post(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
@ -197,9 +196,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, id: UUID):
|
||||
def delete(self, id):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
|
||||
@ -2,13 +2,13 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -24,34 +24,11 @@ class FeatureApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
).model_dump()
|
||||
payload.pop("vector_space", None)
|
||||
return payload
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
|
||||
@console_ns.route("/features/vector-space")
|
||||
class FeatureVectorSpaceApi(Resource):
|
||||
@console_ns.doc("get_tenant_feature_vector_space")
|
||||
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[LimitationModel.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -22,13 +21,10 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
@ -65,8 +61,8 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
source_str = request.form.get("source")
|
||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||
|
||||
@ -110,10 +106,10 @@ class FilePreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -8,14 +8,8 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
@ -76,10 +70,11 @@ class NotificationApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self, current_user: Account):
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
@ -118,11 +113,11 @@ class NotificationDismissApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@ -12,13 +12,11 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -51,8 +49,7 @@ class RemoteFileUpload(Resource):
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
|
||||
@login_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = payload.url
|
||||
|
||||
@ -77,11 +74,12 @@ class RemoteFileUpload(Resource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=current_user,
|
||||
user=user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -9,16 +8,9 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
@ -99,8 +91,8 @@ class TagListApi(Resource):
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
@ -116,9 +108,9 @@ class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Allow users with edit permission, or dataset editors (including dataset operators).
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -139,17 +131,17 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
def patch(self, tag_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id_str)
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = TagResponse.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
@ -162,27 +154,28 @@ class TagUpdateDeleteApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@console_ns.response(204, "Tag deleted successfully")
|
||||
def delete(self, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
def delete(self, tag_id):
|
||||
tag_id = str(tag_id)
|
||||
|
||||
TagService.delete_tag(tag_id_str)
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission(current_user: Account) -> None:
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
@ -195,8 +188,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
@ -219,9 +212,8 @@ class TagBindingCollectionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _create_tag_bindings(current_user)
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
@ -235,6 +227,5 @@ class TagBindingRemoveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _remove_tag_bindings(current_user)
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
from urllib import parse
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -23,15 +21,15 @@ from controllers.console.auth.error import (
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
is_allow_transfer_owner,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
@ -77,55 +75,7 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
|
||||
|
||||
|
||||
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
return list(dict.fromkeys(email.lower() for email in emails))
|
||||
|
||||
|
||||
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
|
||||
new_member_count = 0
|
||||
for email in emails:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
new_member_count += 1
|
||||
continue
|
||||
|
||||
exists = db.session.scalar(
|
||||
select(TenantAccountJoin.id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
if not exists:
|
||||
new_member_count += 1
|
||||
|
||||
return new_member_count
|
||||
|
||||
|
||||
def _count_current_members(tenant_id: str) -> int:
|
||||
return (
|
||||
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
|
||||
)
|
||||
|
||||
|
||||
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
|
||||
if new_member_count <= 0:
|
||||
return
|
||||
|
||||
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
workspace_members = features.workspace_members
|
||||
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
return
|
||||
|
||||
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
|
||||
members = features.members
|
||||
current_member_count = _count_current_members(tenant_id)
|
||||
if 0 < members.limit < current_member_count + new_member_count:
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
@ -154,11 +104,12 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
invitee_emails = _normalize_invitee_emails(args.emails)
|
||||
invitee_emails = args.emails
|
||||
invitee_role = args.role
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
@ -178,36 +129,37 @@ class MemberInviteEmailApi(Resource):
|
||||
invitation_results = []
|
||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||
|
||||
tenant_id = inviter.current_tenant.id
|
||||
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
|
||||
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
|
||||
_check_member_invite_limits(tenant_id, new_member_count)
|
||||
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
if not workspace_members.is_available(len(invitee_emails)):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@ -223,7 +175,7 @@ class MemberCancelInviteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, member_id: UUID):
|
||||
def delete(self, member_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -256,7 +208,7 @@ class MemberUpdateRoleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, member_id: UUID):
|
||||
def put(self, member_id):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberRoleUpdatePayload.model_validate(payload)
|
||||
new_role = args.role
|
||||
@ -399,7 +351,7 @@ class OwnerTransfer(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_allow_transfer_owner
|
||||
def post(self, member_id: UUID):
|
||||
def post(self, member_id):
|
||||
payload = console_ns.payload or {}
|
||||
args = OwnerTransferPayload.model_validate(payload)
|
||||
|
||||
|
||||
@ -532,7 +532,7 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, model_type: str):
|
||||
def get(self, model_type):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
@ -15,7 +15,6 @@ from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
|
||||
@ -166,10 +166,10 @@ class TenantListApi(Resource):
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
|
||||
@ -4,7 +4,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@ -17,7 +16,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.encryption import FieldEncryption
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
@ -84,7 +82,9 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -96,28 +96,21 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(current_tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403,
|
||||
"The capacity of the knowledge storage space has reached the limit of your subscription.",
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
|
||||
)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
@ -147,7 +140,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
@ -205,11 +198,15 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
if dify_config.BILLING_ENABLED and utm_info:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -298,7 +295,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
@ -312,6 +309,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object() # type: ignore
|
||||
if not isinstance(user, Account):
|
||||
@ -329,6 +327,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object()
|
||||
if not isinstance(user, Account) or not user.is_admin_or_owner:
|
||||
@ -496,25 +495,3 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_tenant_id[T, **P, R](
|
||||
view: Callable[Concatenate[T, str, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return view(self, current_tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -50,8 +49,8 @@ class ImagePreviewApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True))
|
||||
timestamp = args.timestamp
|
||||
@ -60,7 +59,7 @@ class ImagePreviewApi(Resource):
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
file_id=file_id_str,
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
@ -92,14 +91,14 @@ class FilePreviewApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id_str,
|
||||
file_id=file_id,
|
||||
timestamp=args.timestamp,
|
||||
nonce=args.nonce,
|
||||
sign=args.sign,
|
||||
@ -160,10 +159,10 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, workspace_id: UUID):
|
||||
workspace_id_str = str(workspace_id)
|
||||
def get(self, workspace_id):
|
||||
workspace_id = str(workspace_id)
|
||||
|
||||
custom_config = TenantService.get_custom_config(workspace_id_str)
|
||||
custom_config = TenantService.get_custom_config(workspace_id)
|
||||
webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
|
||||
|
||||
if not webapp_logo_file_id:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -46,19 +45,17 @@ class ToolFileApi(Resource):
|
||||
415: "Unsupported file type",
|
||||
}
|
||||
)
|
||||
def get(self, file_id: UUID, extension: str):
|
||||
file_id_str = str(file_id)
|
||||
def get(self, file_id, extension):
|
||||
file_id = str(file_id)
|
||||
|
||||
args = ToolFileQuery.model_validate(request.args.to_dict())
|
||||
if not verify_tool_file_signature(
|
||||
file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign
|
||||
):
|
||||
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager()
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id_str,
|
||||
file_id,
|
||||
)
|
||||
|
||||
if not stream or not tool_file:
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
ServerVersionResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspacePayload,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
WorkflowRunData,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
AccountResponse,
|
||||
SessionRow,
|
||||
SessionListResponse,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
FileResponse,
|
||||
ServerVersionResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
_meta,
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
files,
|
||||
human_input_form,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_events,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
# Request models are imported from _models.py and registered above.
|
||||
|
||||
__all__ = [
|
||||
"_meta",
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
"files",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_events",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
@ -1,66 +0,0 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||
|
||||
|
||||
def emit_app_run(
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
caller_kind: str,
|
||||
mode: str,
|
||||
surface: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
surface,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
"surface": surface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_wrong_surface(
|
||||
*,
|
||||
subject_type: str | None,
|
||||
attempted_path: str,
|
||||
client_id: str | None,
|
||||
token_id: str | None,
|
||||
) -> None:
|
||||
logger.warning(
|
||||
"audit: %s subject_type=%s attempted_path=%s",
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
subject_type,
|
||||
attempted_path,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
"subject_type": subject_type,
|
||||
"attempted_path": attempted_path,
|
||||
"client_id": client_id,
|
||||
"token_id": token_id,
|
||||
},
|
||||
)
|
||||
@ -1,143 +0,0 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out: dict[str, Any] = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
|
||||
|
||||
Returns the server's project version and edition so the difyctl CLI can probe
|
||||
compatibility without needing to be logged in. Mirrors the `_health` endpoint
|
||||
in `index.py`.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import ServerVersionResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_version")
|
||||
class VersionApi(Resource):
|
||||
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||
def get(self):
|
||||
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||
return ServerVersionResponse(
|
||||
version=dify_config.project.version,
|
||||
edition=edition,
|
||||
).model_dump(mode="json")
|
||||
@ -1,344 +0,0 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class PermittedExternalAppsListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
class AccountPayload(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspacePayload(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
subject_type: str
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account: AccountPayload | None = None
|
||||
workspaces: list[WorkspacePayload] = []
|
||||
default_workspace_id: str | None = None
|
||||
|
||||
|
||||
class SessionRow(BaseModel):
|
||||
id: str
|
||||
prefix: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
created_at: str | None = None
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WorkspaceSummaryResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
workspaces: list[WorkspaceSummaryResponse]
|
||||
|
||||
|
||||
class WorkspaceDetailResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class DeviceCodeResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceLookupResponse(BaseModel):
|
||||
valid: bool
|
||||
expires_in_remaining: int = 0
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class DeviceMutateResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class ServerVersionResponse(BaseModel):
|
||||
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
|
||||
|
||||
version: str
|
||||
edition: Literal["SELF_HOSTED", "CLOUD"]
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
workspace_id: UUIDStrOrEmpty | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class PermittedExternalAppsListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
|
||||
|
||||
|
||||
class ExtSubjectAssertionClaims(BaseModel):
|
||||
email: str = _EMAIL_FIELD
|
||||
issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
class ApprovalGrantClaimsPayload(BaseModel):
|
||||
subject_email: str = _EMAIL_FIELD
|
||||
subject_issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
@ -1,169 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.oauth_device_flow import (
|
||||
list_active_sessions,
|
||||
revoke_oauth_token,
|
||||
token_belongs_to_subject,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = list_active_sessions(db.session, ctx, now)
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
SessionRow(
|
||||
id=str(r.id),
|
||||
prefix=r.prefix,
|
||||
client_id=r.client_id,
|
||||
device_label=r.device_label,
|
||||
created_at=_iso(r.created_at),
|
||||
last_used_at=_iso(r.last_used_at),
|
||||
expires_at=_iso(r.expires_at),
|
||||
)
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
if not token_belongs_to_subject(db.session, session_id, ctx):
|
||||
raise NotFound("session not found")
|
||||
|
||||
revoke_oauth_token(db.session, redis_client, session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> WorkspacePayload:
|
||||
join, tenant = row
|
||||
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||
|
||||
|
||||
def _account_payload(account) -> AccountPayload:
|
||||
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||
@ -1,165 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
@ -1,270 +0,0 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
|
||||
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name = None
|
||||
if pagination.items:
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
|
||||
env = AppListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=cast(int, pagination.total),
|
||||
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,102 +0,0 @@
|
||||
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||
|
||||
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||
(public / sso_verified). License-gated: CE deploys never enable the
|
||||
EE blueprint chain so this module is unreachable there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
}
|
||||
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
|
||||
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=page_result.total,
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,3 +0,0 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
@ -1,46 +0,0 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
@ -1,68 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
@ -1,51 +0,0 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
@ -1,170 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
@ -1,168 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
@ -1,89 +0,0 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that the resolved subject is in ``accepted``.
|
||||
|
||||
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
|
||||
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
|
||||
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
|
||||
set the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
@ -1,72 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/files/upload")
|
||||
class AppFileUploadApi(Resource):
|
||||
@openapi_ns.doc("upload_file_for_app_input")
|
||||
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
|
||||
@openapi_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request — no file or filename missing",
|
||||
401: "Unauthorized — invalid or expired bearer token",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type or blocked extension",
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=caller,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except services.errors.file.FileTooLargeError as exc:
|
||||
raise FileTooLargeError(exc.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError(exc.description)
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
@ -1,107 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed human input form endpoints.
|
||||
|
||||
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
|
||||
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App
|
||||
from services.human_input_service import FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
if caller_kind == "account":
|
||||
submission_user_id = caller.id
|
||||
else:
|
||||
submission_end_user_id = caller.id
|
||||
|
||||
if form.recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_token=%s", form_token)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
@ -1,9 +0,0 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
@ -1,398 +0,0 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Validation helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
@ -1,365 +0,0 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import bp
|
||||
from controllers.openapi._models import ExtSubjectAssertionClaims
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
def _trusted_origin() -> str:
|
||||
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
|
||||
if not base:
|
||||
raise BadGateway("console_api_url_unset")
|
||||
return base
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
origin = _trusted_origin()
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
try:
|
||||
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
|
||||
except ValidationError as e:
|
||||
logger.warning("sso-complete: claim shape invalid: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = claims.user_code.strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.email):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = _trusted_origin()
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims.email,
|
||||
subject_issuer=claims.issuer,
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
# SSO branch of the shared PollPayload contract: account/workspace
|
||||
# fields are zero-filled (`None` / `[]`) for parity with the account
|
||||
# branch in `oauth_device._build_account_poll_payload`.
|
||||
poll_payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RejectedClaims:
|
||||
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||
|
||||
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||
event without reaching for the full dataclass.
|
||||
"""
|
||||
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
reason,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_rejected",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"reason": reason,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
@ -1,119 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed workflow reconnect event stream endpoint.
|
||||
|
||||
GET /apps/<app_id>/tasks/<task_id>/events
|
||||
— reconnect to the SSE stream for a paused/running workflow run.
|
||||
`task_id` is treated as `workflow_run_id`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if caller_kind == "account":
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
else:
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=caller,
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
else:
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.OPENAPI,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -1,78 +0,0 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
)
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||
return WorkspaceDetailResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
)
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -79,10 +78,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, job_id: UUID, action: str):
|
||||
def get(self, app_model: App, job_id, action):
|
||||
"""Get the status of an annotation reply action job."""
|
||||
job_id_str = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
if cache_result is None:
|
||||
raise ValueError("The job does not exist.")
|
||||
@ -90,10 +89,10 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
job_status = cache_result.decode()
|
||||
error_msg = ""
|
||||
if job_status == "error":
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
|
||||
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
|
||||
error_msg = redis_client.get(app_annotation_error_key).decode()
|
||||
|
||||
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations")
|
||||
@ -174,11 +173,11 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, annotation_id: UUID):
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -195,7 +194,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: UUID):
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
return "", 204
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -196,7 +195,7 @@ class ConversationDetailApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -225,7 +224,7 @@ class ConversationRenameApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -267,7 +266,7 @@ class ConversationVariablesApi(Resource):
|
||||
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""List all variables for a conversation.
|
||||
|
||||
Conversational variables are only available for chat applications.
|
||||
@ -313,7 +312,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
service_api_ns.models[ConversationVariableResponse.__name__],
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
|
||||
def put(self, app_model: App, end_user: EndUser, c_id: UUID, variable_id: UUID):
|
||||
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
|
||||
"""Update a conversation variable's value.
|
||||
|
||||
Allows updating the value of a specific conversation variable.
|
||||
@ -324,13 +323,13 @@ class ConversationVariableDetailApi(Resource):
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
variable = ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id_str, end_user, payload.value
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -51,20 +50,20 @@ class FilePreviewApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: str):
|
||||
"""
|
||||
Preview/Download a file that was uploaded via Service API.
|
||||
|
||||
Provides secure file preview/download functionality.
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
file_id_str = str(file_id)
|
||||
file_id = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
||||
@ -7,6 +7,7 @@ paused human input forms in workflow/chatflow runs.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
@ -18,6 +19,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from graphon.nodes.human_input.entities import FormInputConfig
|
||||
from libs.helper import to_timestamp
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
@ -28,11 +30,11 @@ logger = logging.getLogger(__name__)
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
def _jsonify_form_definition(form: Form, *, inputs: Sequence[FormInputConfig] = ()) -> Response:
|
||||
definition_payload = form.get_definition().model_dump(mode="json")
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"inputs": [form_input.model_dump(mode="json") for form_input in inputs],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
@ -75,7 +77,8 @@ class WorkflowHumanInputFormApi(Resource):
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_service_api(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
inputs = service.resolve_form_inputs(form)
|
||||
return _jsonify_form_definition(form, inputs=inputs)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@service_api_ns.doc("submit_human_input_form")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -95,19 +94,19 @@ class MessageFeedbackApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, message_id):
|
||||
"""Submit feedback for a message.
|
||||
|
||||
Allows users to rate messages as like/dislike and provide optional feedback content.
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id_str,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
@ -160,19 +159,19 @@ class MessageSuggestedApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id):
|
||||
"""Get suggested follow-up questions for a message.
|
||||
|
||||
Returns AI-generated follow-up questions based on the message content.
|
||||
"""
|
||||
message_id_str = str(message_id)
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
questions = MessageService.get_suggested_questions_after_answer(
|
||||
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.SERVICE_API
|
||||
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal, override
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -76,13 +76,11 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -1,18 +1,13 @@
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_enum_models,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@ -22,10 +17,9 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -125,21 +119,6 @@ class TagUnbindingPayload(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class KnowledgeTagResponse(ResponseModel):
|
||||
model_config = ConfigDict(coerce_numbers_to_str=True)
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
# TODO: The public Service API docs expose binding_count as string|null.
|
||||
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
|
||||
binding_count: str | None = None
|
||||
|
||||
|
||||
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
@ -148,29 +127,6 @@ class DatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
# todo: duplicate code, but the partial_member_list has different nullability
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetBoundTagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class DatasetBoundTagListResponse(ResponseModel):
|
||||
data: list[DatasetBoundTagResponse]
|
||||
total: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
@ -181,17 +137,9 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
SimpleResultResponse,
|
||||
KnowledgeTagResponse,
|
||||
KnowledgeTagListResponse,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetBoundTagListResponse,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets")
|
||||
@ -206,18 +154,9 @@ class DatasetListApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Datasets retrieved successfully",
|
||||
service_api_ns.models[DatasetListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
if "tag_ids" in request.args:
|
||||
query_params["tag_ids"] = request.args.getlist("tag_ids")
|
||||
query = DatasetListQuery.model_validate(query_params)
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict())
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
@ -236,17 +175,17 @@ class DatasetListApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = False
|
||||
item["embedding_available"] = False # type: ignore
|
||||
else:
|
||||
item["embedding_available"] = True
|
||||
item["embedding_available"] = True # type: ignore
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
@ -254,7 +193,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@ -266,11 +205,6 @@ class DatasetListApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset created successfully",
|
||||
service_api_ns.models[DatasetDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
@ -314,7 +248,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return dump_response(DatasetDetailResponse, dataset), 200
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -332,12 +266,7 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
def get(self, _, dataset_id: UUID):
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -346,7 +275,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
# check embedding setting
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -378,13 +307,7 @@ class DatasetApi(DatasetApiResource):
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
return (
|
||||
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
|
||||
mode="json",
|
||||
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
|
||||
),
|
||||
200,
|
||||
)
|
||||
return data, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@ -398,13 +321,8 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id: UUID):
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -453,7 +371,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -466,7 +384,7 @@ class DatasetApi(DatasetApiResource):
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
|
||||
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
|
||||
return result_data, 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@ -480,7 +398,7 @@ class DatasetApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, _, dataset_id: UUID):
|
||||
def delete(self, _, dataset_id):
|
||||
"""
|
||||
Deletes a dataset given its ID.
|
||||
|
||||
@ -535,7 +453,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
400: "Bad request - invalid action",
|
||||
}
|
||||
)
|
||||
def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
"""
|
||||
Batch update document status.
|
||||
|
||||
@ -579,7 +497,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
|
||||
@ -592,18 +510,14 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[KnowledgeTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -615,11 +529,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag created successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -629,10 +538,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
|
||||
)
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -645,11 +553,6 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag updated successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -661,10 +564,9 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
|
||||
)
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@ -749,11 +651,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[DatasetBoundTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
@ -761,4 +658,5 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user