mirror of
https://github.com/langgenius/dify.git
synced 2026-05-24 02:47:53 +08:00
Compare commits
76 Commits
codex/refi
...
feat/cli
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d7ee1d761 | |||
| a831920803 | |||
| 98de360447 | |||
| 95816a26b8 | |||
| f39e7d6cd5 | |||
| a6970bc144 | |||
| fecdef6c21 | |||
| b7a2347291 | |||
| 0c1b37687f | |||
| 341a82bf1e | |||
| e71df18d72 | |||
| 152f916768 | |||
| 9b3b408849 | |||
| 102643e060 | |||
| 4c2ba50dfe | |||
| 3df1042706 | |||
| 0f39ac8960 | |||
| 102a9f3eb3 | |||
| d94e302045 | |||
| 2ff07b6311 | |||
| 1554d80df5 | |||
| 7ec50f4656 | |||
| 66c4b9d589 | |||
| cb218f2832 | |||
| ed6a079582 | |||
| f1d68e4178 | |||
| 851bf36f24 | |||
| f6e4d558a6 | |||
| c38c5d375e | |||
| 5381452de9 | |||
| ed5f6b153f | |||
| 6f760a3901 | |||
| 0bf64ca3f2 | |||
| e827aca154 | |||
| 8de813c867 | |||
| d09d360530 | |||
| 8a4c87234f | |||
| 31ea69be66 | |||
| 6de46024a3 | |||
| 0d5173f73f | |||
| fd1ebdd6cb | |||
| 9fe7adaf69 | |||
| 7a6c84dca3 | |||
| 75a8120152 | |||
| ca103b60cc | |||
| 2c90cfa00f | |||
| 6851624dbe | |||
| 44d1b66c93 | |||
| f372eb8e5b | |||
| 36101c7126 | |||
| fe212003b1 | |||
| 948214fe6a | |||
| 14328634b5 | |||
| de0a44be06 | |||
| 6153a6b663 | |||
| d5dee5326e | |||
| 49b33647e7 | |||
| badfd7689a | |||
| 0ff00e742f | |||
| a89b43bccc | |||
| c6792ce415 | |||
| 8918142ce1 | |||
| e2d6ae818c | |||
| 2fd7b82970 | |||
| 1cc7953f79 | |||
| 31cf656b35 | |||
| 8be6665d22 | |||
| c2b91d849d | |||
| e0f4e98a2f | |||
| 9d554495cf | |||
| c2868075fa | |||
| 1a83dfaf1f | |||
| 83d14e0540 | |||
| 1f7da9c191 | |||
| b21d0ae32d | |||
| 6779366dca |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
**/node_modules
|
||||||
|
**/.pnpm-store
|
||||||
|
**/dist
|
||||||
|
**/.next
|
||||||
|
**/.turbo
|
||||||
|
**/.cache
|
||||||
|
**/__pycache__
|
||||||
|
**/*.pyc
|
||||||
|
**/.mypy_cache
|
||||||
|
**/.ruff_cache
|
||||||
|
.git
|
||||||
|
.github
|
||||||
|
*.md
|
||||||
|
!web/README.md
|
||||||
|
!api/README.md
|
||||||
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -5,3 +5,7 @@
|
|||||||
# them.
|
# them.
|
||||||
|
|
||||||
*.sh text eol=lf
|
*.sh text eol=lf
|
||||||
|
|
||||||
|
# Codegen output must stay byte-identical across platforms so
|
||||||
|
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
|
||||||
|
*.generated.ts text eol=lf
|
||||||
|
|||||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,6 +18,10 @@
|
|||||||
# Docs
|
# Docs
|
||||||
/docs/ @crazywoola
|
/docs/ @crazywoola
|
||||||
|
|
||||||
|
# CLI
|
||||||
|
/cli/ @langgenius/maintainers
|
||||||
|
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||||
|
|
||||||
# Backend (default owner, more specific rules below will override)
|
# Backend (default owner, more specific rules below will override)
|
||||||
/api/ @QuantumGhost
|
/api/ @QuantumGhost
|
||||||
|
|
||||||
|
|||||||
88
.github/workflows/cli-release.yml
vendored
Normal file
88
.github/workflows/cli-release.yml
vendored
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
name: CLI Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'difyctl-v*'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: cli-release-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
name: build standalone binaries (all targets)
|
||||||
|
runs-on: depot-ubuntu-24.04
|
||||||
|
if: github.repository == 'langgenius/dify'
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: ./cli
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup web environment
|
||||||
|
uses: ./.github/actions/setup-web
|
||||||
|
|
||||||
|
- name: Setup Bun
|
||||||
|
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||||
|
with:
|
||||||
|
bun-version: latest
|
||||||
|
|
||||||
|
- name: Read cli/package.json
|
||||||
|
id: manifest
|
||||||
|
run: |
|
||||||
|
version=$(node -p "require('./package.json').version")
|
||||||
|
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||||
|
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||||
|
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||||
|
{
|
||||||
|
echo "version=$version"
|
||||||
|
echo "channel=$channel"
|
||||||
|
echo "minDify=$minDify"
|
||||||
|
echo "maxDify=$maxDify"
|
||||||
|
} >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Validate manifest
|
||||||
|
run: scripts/release-validate-manifest.sh
|
||||||
|
|
||||||
|
- name: Install cross-arch native prebuilds
|
||||||
|
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||||
|
# so `bun build --compile` can embed the right .node into each target.
|
||||||
|
working-directory: ./
|
||||||
|
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||||
|
|
||||||
|
- name: Compile standalone binaries (all targets)
|
||||||
|
env:
|
||||||
|
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||||
|
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||||
|
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||||
|
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||||
|
run: |
|
||||||
|
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||||
|
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||||
|
pnpm build:bin
|
||||||
|
|
||||||
|
- name: Generate sha256 checksum file
|
||||||
|
env:
|
||||||
|
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||||
|
run: scripts/release-write-checksums.sh
|
||||||
|
|
||||||
|
- name: Publish GitHub Release
|
||||||
|
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||||
|
with:
|
||||||
|
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||||
|
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||||
|
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||||
|
generate_release_notes: true
|
||||||
|
fail_on_unmatched_files: true
|
||||||
|
files: |
|
||||||
|
cli/dist/bin/difyctl-v*
|
||||||
60
.github/workflows/cli-smoke.yml
vendored
Normal file
60
.github/workflows/cli-smoke.yml
vendored
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
name: CLI Smoke (live dify)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
dify_version:
|
||||||
|
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
cli_ref:
|
||||||
|
description: "Git ref to build the cli from (default: current branch)"
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
smoke:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 30
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
steps:
|
||||||
|
- name: Checkout cli ref
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.cli_ref || github.ref }}
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
- name: Setup web environment
|
||||||
|
uses: ./.github/actions/setup-web
|
||||||
|
|
||||||
|
- name: Bring up dify
|
||||||
|
env:
|
||||||
|
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||||
|
run: |
|
||||||
|
cd docker
|
||||||
|
cp .env.example .env
|
||||||
|
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||||
|
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||||
|
docker compose up -d api worker web db redis
|
||||||
|
for i in $(seq 1 60); do
|
||||||
|
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||||
|
echo "dify api ready after ${i}s"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Run smoke against live dify
|
||||||
|
working-directory: ./cli
|
||||||
|
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||||
|
|
||||||
|
- name: Dump dify logs on failure
|
||||||
|
if: failure()
|
||||||
|
run: |
|
||||||
|
cd docker
|
||||||
|
docker compose logs api worker web --tail=200
|
||||||
46
.github/workflows/cli-tests.yml
vendored
Normal file
46
.github/workflows/cli-tests.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
name: CLI Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
secrets:
|
||||||
|
CODECOV_TOKEN:
|
||||||
|
required: false
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
name: CLI Tests
|
||||||
|
runs-on: depot-ubuntu-24.04
|
||||||
|
env:
|
||||||
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
defaults:
|
||||||
|
run:
|
||||||
|
shell: bash
|
||||||
|
working-directory: ./cli
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
- name: Setup web environment
|
||||||
|
uses: ./.github/actions/setup-web
|
||||||
|
|
||||||
|
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||||
|
run: pnpm ci
|
||||||
|
|
||||||
|
- name: Report coverage
|
||||||
|
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||||
|
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||||
|
with:
|
||||||
|
directory: cli/coverage
|
||||||
|
flags: cli
|
||||||
|
env:
|
||||||
|
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,6 +42,7 @@ jobs:
|
|||||||
runs-on: depot-ubuntu-24.04
|
runs-on: depot-ubuntu-24.04
|
||||||
outputs:
|
outputs:
|
||||||
api-changed: ${{ steps.changes.outputs.api }}
|
api-changed: ${{ steps.changes.outputs.api }}
|
||||||
|
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||||
web-changed: ${{ steps.changes.outputs.web }}
|
web-changed: ${{ steps.changes.outputs.web }}
|
||||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||||
@ -62,6 +63,18 @@ jobs:
|
|||||||
- 'docker/generate_docker_compose'
|
- 'docker/generate_docker_compose'
|
||||||
- 'docker/ssrf_proxy/**'
|
- 'docker/ssrf_proxy/**'
|
||||||
- 'docker/volumes/sandbox/conf/**'
|
- 'docker/volumes/sandbox/conf/**'
|
||||||
|
cli:
|
||||||
|
- 'cli/**'
|
||||||
|
- 'packages/tsconfig/**'
|
||||||
|
- 'package.json'
|
||||||
|
- 'pnpm-lock.yaml'
|
||||||
|
- 'pnpm-workspace.yaml'
|
||||||
|
- 'eslint.config.mjs'
|
||||||
|
- '.npmrc'
|
||||||
|
- '.nvmrc'
|
||||||
|
- '.github/workflows/cli-tests.yml'
|
||||||
|
- '.github/workflows/cli-docker-build.yml'
|
||||||
|
- '.github/actions/setup-web/**'
|
||||||
web:
|
web:
|
||||||
- 'web/**'
|
- 'web/**'
|
||||||
- 'packages/**'
|
- 'packages/**'
|
||||||
@ -184,6 +197,66 @@ jobs:
|
|||||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||||
exit 1
|
exit 1
|
||||||
|
|
||||||
|
cli-tests-run:
|
||||||
|
name: Run CLI Tests
|
||||||
|
needs:
|
||||||
|
- pre_job
|
||||||
|
- check-changes
|
||||||
|
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||||
|
uses: ./.github/workflows/cli-tests.yml
|
||||||
|
secrets: inherit
|
||||||
|
|
||||||
|
cli-tests-skip:
|
||||||
|
name: Skip CLI Tests
|
||||||
|
needs:
|
||||||
|
- pre_job
|
||||||
|
- check-changes
|
||||||
|
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||||
|
runs-on: depot-ubuntu-24.04
|
||||||
|
steps:
|
||||||
|
- name: Report skipped CLI tests
|
||||||
|
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||||
|
|
||||||
|
cli-tests:
|
||||||
|
name: CLI Tests
|
||||||
|
if: ${{ always() }}
|
||||||
|
needs:
|
||||||
|
- pre_job
|
||||||
|
- check-changes
|
||||||
|
- cli-tests-run
|
||||||
|
- cli-tests-skip
|
||||||
|
runs-on: depot-ubuntu-24.04
|
||||||
|
steps:
|
||||||
|
- name: Finalize CLI Tests status
|
||||||
|
env:
|
||||||
|
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||||
|
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||||
|
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||||
|
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||||
|
run: |
|
||||||
|
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||||
|
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||||
|
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||||
|
echo "CLI tests ran successfully."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||||
|
echo "CLI tests were skipped because no CLI-related files changed."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||||
|
exit 1
|
||||||
|
|
||||||
web-tests-run:
|
web-tests-run:
|
||||||
name: Run Web Tests
|
name: Run Web Tests
|
||||||
needs:
|
needs:
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -115,6 +115,12 @@ venv/
|
|||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
|
|
||||||
|
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||||
|
!/cli/src/env/
|
||||||
|
!/cli/src/commands/env/
|
||||||
|
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||||
|
!/cli/scripts/lib/
|
||||||
.conda/
|
.conda/
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
@ -247,6 +253,7 @@ scripts/stress-test/reports/
|
|||||||
# settings
|
# settings
|
||||||
*.local.json
|
*.local.json
|
||||||
*.local.md
|
*.local.md
|
||||||
|
*.local.toml
|
||||||
|
|
||||||
# Code Agent Folder
|
# Code Agent Folder
|
||||||
.qoder/*
|
.qoder/*
|
||||||
|
|||||||
@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
|
|||||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
|
||||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||||
|
|
||||||
# Marketplace configuration
|
# Marketplace configuration
|
||||||
|
|||||||
@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
|
|||||||
ext_logstore,
|
ext_logstore,
|
||||||
ext_mail,
|
ext_mail,
|
||||||
ext_migrate,
|
ext_migrate,
|
||||||
|
ext_oauth_bearer,
|
||||||
ext_orjson,
|
ext_orjson,
|
||||||
ext_otel,
|
ext_otel,
|
||||||
ext_proxy_fix,
|
ext_proxy_fix,
|
||||||
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
|
|||||||
ext_enterprise_telemetry,
|
ext_enterprise_telemetry,
|
||||||
ext_request_logging,
|
ext_request_logging,
|
||||||
ext_session_factory,
|
ext_session_factory,
|
||||||
|
ext_oauth_bearer,
|
||||||
]
|
]
|
||||||
for ext in extensions:
|
for ext in extensions:
|
||||||
short_name = ext.__name__.split(".")[-1]
|
short_name = ext.__name__.split(".")[-1]
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from configs import dify_config
|
|||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.utils.system_encryption import encrypt_system_params
|
from core.tools.utils.system_encryption import encrypt_system_params
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Tenant
|
from models import Tenant
|
||||||
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
|||||||
from models.tools import ToolOAuthSystemClient
|
from models.tools import ToolOAuthSystemClient
|
||||||
from services.plugin.data_migration import PluginDataMigration
|
from services.plugin.data_migration import PluginDataMigration
|
||||||
from services.plugin.plugin_migration import PluginMigration
|
from services.plugin.plugin_migration import PluginMigration
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
@ -23,7 +25,7 @@ class DeploymentConfig(BaseSettings):
|
|||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
EDITION: str = Field(
|
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
|
||||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||||
default="SELF_HOSTED",
|
default="SELF_HOSTED",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
|
|||||||
default=60 * 60,
|
default=60 * 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
|
|
||||||
description="TTL in seconds for caching tenant plugin model providers in Redis",
|
|
||||||
default=60 * 60 * 24,
|
|
||||||
)
|
|
||||||
|
|
||||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||||
default=50 * 1024 * 1024,
|
default=50 * 1024 * 1024,
|
||||||
@ -525,6 +520,44 @@ class HttpConfig(BaseSettings):
|
|||||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||||
|
|
||||||
|
OPENAPI_ENABLED: bool = Field(
|
||||||
|
description=(
|
||||||
|
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||||
|
"programmatic clients. Set to true to activate; disabled by default."
|
||||||
|
),
|
||||||
|
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||||
|
description=(
|
||||||
|
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||||
|
"Default empty = same-origin only. Browser-cookie routes within "
|
||||||
|
"the group reject cross-origin OPTIONS regardless of this list."
|
||||||
|
),
|
||||||
|
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||||
|
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||||
|
|
||||||
|
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||||
|
description=(
|
||||||
|
"Comma-separated client_id values accepted at "
|
||||||
|
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||||
|
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||||
|
),
|
||||||
|
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||||
|
default="difyctl",
|
||||||
|
)
|
||||||
|
|
||||||
|
@computed_field # type: ignore[misc]
|
||||||
|
@property
|
||||||
|
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||||
|
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||||
|
|
||||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||||
)
|
)
|
||||||
@ -900,6 +933,17 @@ class AuthConfig(BaseSettings):
|
|||||||
default=86400,
|
default=86400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ENABLE_OAUTH_BEARER: bool = Field(
|
||||||
|
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||||
|
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||||
|
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||||
|
default=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModerationConfig(BaseSettings):
|
class ModerationConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@ -1186,6 +1230,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
|||||||
description="Enable scheduled workflow run cleanup task",
|
description="Enable scheduled workflow run cleanup task",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||||
|
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||||
|
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||||
|
default=30,
|
||||||
|
)
|
||||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||||
description="Enable mail clean document notify task",
|
description="Enable mail clean document notify task",
|
||||||
default=False,
|
default=False,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, override
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
|
|||||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||||
|
|
||||||
@override
|
|
||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
if not isinstance(self.remote_configs, dict):
|
if not isinstance(self.remote_configs, dict):
|
||||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, override
|
from typing import Any
|
||||||
|
|
||||||
from pydantic.fields import FieldInfo
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to parse config: {e}")
|
raise RuntimeError(f"Failed to parse config: {e}")
|
||||||
|
|
||||||
@override
|
|
||||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
field_value = self.remote_configs.get(field_name)
|
field_value = self.remote_configs.get(field_name)
|
||||||
if field_value is None:
|
if field_value is None:
|
||||||
|
|||||||
@ -10,7 +10,7 @@ import threading
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import Callable, Generator
|
||||||
from contextlib import AbstractContextManager, contextmanager
|
from contextlib import AbstractContextManager, contextmanager
|
||||||
from typing import Any, Protocol, final, override, runtime_checkable
|
from typing import Any, Protocol, final, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
|
|||||||
self._config = config or {}
|
self._config = config or {}
|
||||||
self._extensions: dict[str, Any] = {}
|
self._extensions: dict[str, Any] = {}
|
||||||
|
|
||||||
@override
|
|
||||||
def get_config(self, key: str, default: Any = None) -> Any:
|
def get_config(self, key: str, default: Any = None) -> Any:
|
||||||
"""Get configuration value by key."""
|
"""Get configuration value by key."""
|
||||||
return self._config.get(key, default)
|
return self._config.get(key, default)
|
||||||
|
|
||||||
@override
|
|
||||||
def get_extension(self, name: str) -> Any:
|
def get_extension(self, name: str) -> Any:
|
||||||
"""Get extension by name."""
|
"""Get extension by name."""
|
||||||
return self._extensions.get(name)
|
return self._extensions.get(name)
|
||||||
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
|
|||||||
self._extensions[name] = extension
|
self._extensions[name] = extension
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@override
|
|
||||||
def enter(self) -> Generator[None, None, None]:
|
def enter(self) -> Generator[None, None, None]:
|
||||||
"""Enter null context (no-op)."""
|
"""Enter null context (no-op)."""
|
||||||
yield
|
yield
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import contextvars
|
|||||||
import threading
|
import threading
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, final, override
|
from typing import Any, final
|
||||||
|
|
||||||
from flask import Flask, current_app, g
|
from flask import Flask, current_app, g
|
||||||
|
|
||||||
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
|
|||||||
"""
|
"""
|
||||||
self._flask_app = flask_app
|
self._flask_app = flask_app
|
||||||
|
|
||||||
@override
|
|
||||||
def get_config(self, key: str, default: Any = None) -> Any:
|
def get_config(self, key: str, default: Any = None) -> Any:
|
||||||
"""Get configuration value from Flask app config."""
|
"""Get configuration value from Flask app config."""
|
||||||
return self._flask_app.config.get(key, default)
|
return self._flask_app.config.get(key, default)
|
||||||
|
|
||||||
@override
|
|
||||||
def get_extension(self, name: str) -> Any:
|
def get_extension(self, name: str) -> Any:
|
||||||
"""Get Flask extension by name."""
|
"""Get Flask extension by name."""
|
||||||
return self._flask_app.extensions.get(name)
|
return self._flask_app.extensions.get(name)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@override
|
|
||||||
def enter(self) -> Generator[None, None, None]:
|
def enter(self) -> Generator[None, None, None]:
|
||||||
"""Enter Flask app context."""
|
"""Enter Flask app context."""
|
||||||
with self._flask_app.app_context():
|
with self._flask_app.app_context():
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from controllers.console import console_ns
|
|||||||
from controllers.console.workspace import plugin_permission_required
|
from controllers.console.workspace import plugin_permission_required
|
||||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from fields.base import ResponseModel
|
from fields.base import ResponseModel
|
||||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from libs.login import current_account_with_tenant, login_required
|
from libs.login import current_account_with_tenant, login_required
|
||||||
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
|
|||||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
|
|
||||||
class ParserList(BaseModel):
|
class ParserList(BaseModel):
|
||||||
|
|||||||
128
api/controllers/openapi/__init__.py
Normal file
128
api/controllers/openapi/__init__.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
from flask import Blueprint
|
||||||
|
from flask_restx import Namespace
|
||||||
|
|
||||||
|
from libs.device_flow_security import attach_anti_framing
|
||||||
|
from libs.external_api import ExternalApi
|
||||||
|
|
||||||
|
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||||
|
attach_anti_framing(bp)
|
||||||
|
|
||||||
|
api = ExternalApi(
|
||||||
|
bp,
|
||||||
|
version="1.0",
|
||||||
|
title="OpenAPI",
|
||||||
|
description="User-scoped programmatic API (bearer auth)",
|
||||||
|
)
|
||||||
|
|
||||||
|
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||||
|
|
||||||
|
# Register response/query models BEFORE importing controller modules so that
|
||||||
|
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||||
|
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||||
|
from controllers.openapi._models import (
|
||||||
|
AccountPayload,
|
||||||
|
AccountResponse,
|
||||||
|
AppDescribeInfo,
|
||||||
|
AppDescribeQuery,
|
||||||
|
AppDescribeResponse,
|
||||||
|
AppInfoResponse,
|
||||||
|
AppListQuery,
|
||||||
|
AppListResponse,
|
||||||
|
AppListRow,
|
||||||
|
AppRunRequest,
|
||||||
|
DeviceCodeRequest,
|
||||||
|
DeviceCodeResponse,
|
||||||
|
DeviceLookupQuery,
|
||||||
|
DeviceLookupResponse,
|
||||||
|
DeviceMutateRequest,
|
||||||
|
DeviceMutateResponse,
|
||||||
|
DevicePollRequest,
|
||||||
|
MessageMetadata,
|
||||||
|
PermittedExternalAppsListQuery,
|
||||||
|
PermittedExternalAppsListResponse,
|
||||||
|
RevokeResponse,
|
||||||
|
ServerVersionResponse,
|
||||||
|
SessionListResponse,
|
||||||
|
SessionRow,
|
||||||
|
TagItem,
|
||||||
|
UsageInfo,
|
||||||
|
WorkflowRunData,
|
||||||
|
WorkspaceDetailResponse,
|
||||||
|
WorkspaceListResponse,
|
||||||
|
WorkspacePayload,
|
||||||
|
WorkspaceSummaryResponse,
|
||||||
|
)
|
||||||
|
from fields.file_fields import FileResponse
|
||||||
|
|
||||||
|
register_schema_models(
|
||||||
|
openapi_ns,
|
||||||
|
AppDescribeQuery,
|
||||||
|
AppListQuery,
|
||||||
|
AppRunRequest,
|
||||||
|
DeviceCodeRequest,
|
||||||
|
DevicePollRequest,
|
||||||
|
DeviceLookupQuery,
|
||||||
|
DeviceMutateRequest,
|
||||||
|
PermittedExternalAppsListQuery,
|
||||||
|
)
|
||||||
|
register_response_schema_models(
|
||||||
|
openapi_ns,
|
||||||
|
TagItem,
|
||||||
|
UsageInfo,
|
||||||
|
MessageMetadata,
|
||||||
|
AppListRow,
|
||||||
|
AppListResponse,
|
||||||
|
AppInfoResponse,
|
||||||
|
AppDescribeInfo,
|
||||||
|
AppDescribeResponse,
|
||||||
|
WorkflowRunData,
|
||||||
|
AccountPayload,
|
||||||
|
WorkspacePayload,
|
||||||
|
AccountResponse,
|
||||||
|
SessionRow,
|
||||||
|
SessionListResponse,
|
||||||
|
PermittedExternalAppsListResponse,
|
||||||
|
RevokeResponse,
|
||||||
|
WorkspaceSummaryResponse,
|
||||||
|
WorkspaceListResponse,
|
||||||
|
WorkspaceDetailResponse,
|
||||||
|
DeviceCodeResponse,
|
||||||
|
DeviceLookupResponse,
|
||||||
|
DeviceMutateResponse,
|
||||||
|
FileResponse,
|
||||||
|
ServerVersionResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from . import (
|
||||||
|
_meta,
|
||||||
|
account,
|
||||||
|
app_run,
|
||||||
|
apps,
|
||||||
|
apps_permitted_external,
|
||||||
|
files,
|
||||||
|
human_input_form,
|
||||||
|
index,
|
||||||
|
oauth_device,
|
||||||
|
oauth_device_sso,
|
||||||
|
workflow_events,
|
||||||
|
workspaces,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request models are imported from _models.py and registered above.
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"_meta",
|
||||||
|
"account",
|
||||||
|
"app_run",
|
||||||
|
"apps",
|
||||||
|
"apps_permitted_external",
|
||||||
|
"files",
|
||||||
|
"human_input_form",
|
||||||
|
"index",
|
||||||
|
"oauth_device",
|
||||||
|
"oauth_device_sso",
|
||||||
|
"workflow_events",
|
||||||
|
"workspaces",
|
||||||
|
]
|
||||||
|
|
||||||
|
api.add_namespace(openapi_ns)
|
||||||
66
api/controllers/openapi/_audit.py
Normal file
66
api/controllers/openapi/_audit.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
"""Audit emission for openapi app-run endpoints.
|
||||||
|
|
||||||
|
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||||
|
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||||
|
its own allowlist to decide whether to ship the line.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||||
|
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||||
|
|
||||||
|
|
||||||
|
def emit_app_run(
|
||||||
|
*,
|
||||||
|
app_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
caller_kind: str,
|
||||||
|
mode: str,
|
||||||
|
surface: str,
|
||||||
|
) -> None:
|
||||||
|
logger.info(
|
||||||
|
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||||
|
EVENT_APP_RUN_OPENAPI,
|
||||||
|
app_id,
|
||||||
|
tenant_id,
|
||||||
|
caller_kind,
|
||||||
|
mode,
|
||||||
|
surface,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": EVENT_APP_RUN_OPENAPI,
|
||||||
|
"app_id": app_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"caller_kind": caller_kind,
|
||||||
|
"mode": mode,
|
||||||
|
"surface": surface,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def emit_wrong_surface(
|
||||||
|
*,
|
||||||
|
subject_type: str | None,
|
||||||
|
attempted_path: str,
|
||||||
|
client_id: str | None,
|
||||||
|
token_id: str | None,
|
||||||
|
) -> None:
|
||||||
|
logger.warning(
|
||||||
|
"audit: %s subject_type=%s attempted_path=%s",
|
||||||
|
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||||
|
subject_type,
|
||||||
|
attempted_path,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||||
|
"subject_type": subject_type,
|
||||||
|
"attempted_path": attempted_path,
|
||||||
|
"client_id": client_id,
|
||||||
|
"token_id": token_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
143
api/controllers/openapi/_input_schema.py
Normal file
143
api/controllers/openapi/_input_schema.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from controllers.service_api.app.error import AppUnavailableError
|
||||||
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||||
|
|
||||||
|
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||||
|
"$schema": JSON_SCHEMA_DRAFT,
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||||
|
|
||||||
|
|
||||||
|
def _file_object_shape() -> dict[str, Any]:
|
||||||
|
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {"type": "string"},
|
||||||
|
"transfer_method": {"type": "string"},
|
||||||
|
"url": {"type": "string"},
|
||||||
|
"upload_file_id": {"type": "string"},
|
||||||
|
},
|
||||||
|
"additionalProperties": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
label = row.get("label") or row.get("variable", "")
|
||||||
|
base: dict[str, Any] = {"title": label} if label else {}
|
||||||
|
|
||||||
|
if row_type in ("text-input", "paragraph"):
|
||||||
|
out: dict[str, Any] = {"type": "string"} | base
|
||||||
|
max_length = row.get("max_length")
|
||||||
|
if isinstance(max_length, int) and max_length > 0:
|
||||||
|
out["maxLength"] = max_length
|
||||||
|
return out
|
||||||
|
|
||||||
|
if row_type == "select":
|
||||||
|
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||||
|
|
||||||
|
if row_type == "number":
|
||||||
|
return {"type": "number"} | base
|
||||||
|
|
||||||
|
if row_type == "file":
|
||||||
|
return _file_object_shape() | base
|
||||||
|
|
||||||
|
if row_type == "file-list":
|
||||||
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": _file_object_shape(),
|
||||||
|
} | base
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||||
|
"""Translate a user_input_form row list into (properties, required-list).
|
||||||
|
|
||||||
|
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||||
|
Unknown variable types are skipped (forward-compat).
|
||||||
|
"""
|
||||||
|
properties: dict[str, Any] = {}
|
||||||
|
required: list[str] = []
|
||||||
|
for row in form:
|
||||||
|
if not isinstance(row, dict) or len(row) != 1:
|
||||||
|
continue
|
||||||
|
((row_type, row_body),) = row.items()
|
||||||
|
if not isinstance(row_body, dict):
|
||||||
|
continue
|
||||||
|
variable = row_body.get("variable")
|
||||||
|
if not variable:
|
||||||
|
continue
|
||||||
|
schema = _row_to_schema(row_type, row_body)
|
||||||
|
if schema is None:
|
||||||
|
continue
|
||||||
|
properties[variable] = schema
|
||||||
|
if row_body.get("required"):
|
||||||
|
required.append(variable)
|
||||||
|
return properties, required
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||||
|
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||||
|
|
||||||
|
Raises `AppUnavailableError` on misconfigured apps.
|
||||||
|
"""
|
||||||
|
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||||
|
workflow = app.workflow
|
||||||
|
if workflow is None:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
return (
|
||||||
|
workflow.features_dict,
|
||||||
|
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
app_model_config = app.app_model_config
|
||||||
|
if app_model_config is None:
|
||||||
|
raise AppUnavailableError()
|
||||||
|
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||||
|
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||||
|
|
||||||
|
|
||||||
|
def build_input_schema(app: App) -> dict[str, Any]:
|
||||||
|
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||||
|
|
||||||
|
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||||
|
completion / workflow: `inputs` object only.
|
||||||
|
Raises `AppUnavailableError` on misconfigured apps.
|
||||||
|
"""
|
||||||
|
_, user_input_form = resolve_app_config(app)
|
||||||
|
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||||
|
|
||||||
|
properties: dict[str, Any] = {}
|
||||||
|
required: list[str] = []
|
||||||
|
|
||||||
|
if app.mode in _CHAT_FAMILY:
|
||||||
|
properties["query"] = {"type": "string", "minLength": 1}
|
||||||
|
required.append("query")
|
||||||
|
|
||||||
|
properties["inputs"] = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": inputs_props,
|
||||||
|
"required": inputs_required,
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
required.append("inputs")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"$schema": JSON_SCHEMA_DRAFT,
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
"required": required,
|
||||||
|
}
|
||||||
23
api/controllers/openapi/_meta.py
Normal file
23
api/controllers/openapi/_meta.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
|
||||||
|
|
||||||
|
Returns the server's project version and edition so the difyctl CLI can probe
|
||||||
|
compatibility without needing to be logged in. Mirrors the `_health` endpoint
|
||||||
|
in `index.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from flask_restx import Resource
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._models import ServerVersionResponse
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/_version")
|
||||||
|
class VersionApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||||
|
def get(self):
|
||||||
|
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||||
|
return ServerVersionResponse(
|
||||||
|
version=dify_config.project.version,
|
||||||
|
edition=edition,
|
||||||
|
).model_dump(mode="json")
|
||||||
326
api/controllers/openapi/_models.py
Normal file
326
api/controllers/openapi/_models.py
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
"""Shared response substructures for openapi endpoints."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||||
|
MAX_PAGE_LIMIT = 200
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(BaseModel):
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
completion_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class MessageMetadata(BaseModel):
|
||||||
|
usage: UsageInfo | None = None
|
||||||
|
retriever_resources: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationEnvelope[T](BaseModel):
|
||||||
|
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||||
|
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[T]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||||
|
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||||
|
|
||||||
|
|
||||||
|
class TagItem(BaseModel):
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class AppListRow(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
mode: AppMode
|
||||||
|
tags: list[TagItem] = []
|
||||||
|
updated_at: str | None = None
|
||||||
|
created_by_name: str | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
|
workspace_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AppListResponse(BaseModel):
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[AppListRow]
|
||||||
|
|
||||||
|
|
||||||
|
class PermittedExternalAppsListResponse(BaseModel):
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[AppListRow]
|
||||||
|
|
||||||
|
|
||||||
|
class AppInfoResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
mode: str
|
||||||
|
author: str | None = None
|
||||||
|
tags: list[TagItem] = []
|
||||||
|
|
||||||
|
|
||||||
|
class AppDescribeInfo(AppInfoResponse):
|
||||||
|
updated_at: str | None = None
|
||||||
|
service_api_enabled: bool
|
||||||
|
is_agent: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AppDescribeResponse(BaseModel):
|
||||||
|
info: AppDescribeInfo | None = None
|
||||||
|
parameters: dict[str, Any] | None = None
|
||||||
|
input_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessageResponse(BaseModel):
|
||||||
|
event: str
|
||||||
|
task_id: str
|
||||||
|
id: str
|
||||||
|
message_id: str
|
||||||
|
conversation_id: str
|
||||||
|
mode: str
|
||||||
|
answer: str
|
||||||
|
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionMessageResponse(BaseModel):
|
||||||
|
event: str
|
||||||
|
task_id: str
|
||||||
|
id: str
|
||||||
|
message_id: str
|
||||||
|
mode: str
|
||||||
|
answer: str
|
||||||
|
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunData(BaseModel):
|
||||||
|
id: str
|
||||||
|
workflow_id: str
|
||||||
|
status: str
|
||||||
|
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
error: str | None = None
|
||||||
|
elapsed_time: float | None = None
|
||||||
|
total_tokens: int | None = None
|
||||||
|
total_steps: int | None = None
|
||||||
|
created_at: int | None = None
|
||||||
|
finished_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunResponse(BaseModel):
|
||||||
|
workflow_run_id: str
|
||||||
|
task_id: str
|
||||||
|
mode: Literal["workflow"] = "workflow"
|
||||||
|
data: WorkflowRunData
|
||||||
|
|
||||||
|
|
||||||
|
class AccountPayload(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspacePayload(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
|
||||||
|
|
||||||
|
class AccountResponse(BaseModel):
|
||||||
|
subject_type: str
|
||||||
|
subject_email: str | None = None
|
||||||
|
subject_issuer: str | None = None
|
||||||
|
account: AccountPayload | None = None
|
||||||
|
workspaces: list[WorkspacePayload] = []
|
||||||
|
default_workspace_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRow(BaseModel):
|
||||||
|
id: str
|
||||||
|
prefix: str
|
||||||
|
client_id: str
|
||||||
|
device_label: str
|
||||||
|
created_at: str | None = None
|
||||||
|
last_used_at: str | None = None
|
||||||
|
expires_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SessionListResponse(BaseModel):
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
data: list[SessionRow]
|
||||||
|
|
||||||
|
|
||||||
|
class RevokeResponse(BaseModel):
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceSummaryResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
status: str
|
||||||
|
current: bool
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceListResponse(BaseModel):
|
||||||
|
workspaces: list[WorkspaceSummaryResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceDetailResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
role: str
|
||||||
|
status: str
|
||||||
|
current: bool
|
||||||
|
created_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceCodeResponse(BaseModel):
|
||||||
|
device_code: str
|
||||||
|
user_code: str
|
||||||
|
verification_uri: str
|
||||||
|
expires_in: int
|
||||||
|
interval: int
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceLookupResponse(BaseModel):
|
||||||
|
valid: bool
|
||||||
|
expires_in_remaining: int = 0
|
||||||
|
client_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceMutateResponse(BaseModel):
|
||||||
|
status: str
|
||||||
|
|
||||||
|
|
||||||
|
class ServerVersionResponse(BaseModel):
|
||||||
|
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
|
||||||
|
|
||||||
|
version: str
|
||||||
|
edition: Literal["SELF_HOSTED", "CLOUD"]
|
||||||
|
|
||||||
|
|
||||||
|
class AppDescribeQuery(BaseModel):
|
||||||
|
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||||
|
|
||||||
|
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
fields: set[str] | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
|
|
||||||
|
@field_validator("workspace_id", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||||
|
if v is None or v == "":
|
||||||
|
return None
|
||||||
|
if not isinstance(v, str):
|
||||||
|
raise ValueError("workspace_id must be a string")
|
||||||
|
try:
|
||||||
|
import uuid as _uuid
|
||||||
|
|
||||||
|
_uuid.UUID(v)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("workspace_id must be a valid UUID")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("fields", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||||
|
if v is None or v == "":
|
||||||
|
return None
|
||||||
|
if not isinstance(v, str):
|
||||||
|
raise ValueError("fields must be a comma-separated string")
|
||||||
|
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||||
|
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||||
|
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||||
|
return members
|
||||||
|
|
||||||
|
|
||||||
|
class AppListQuery(BaseModel):
|
||||||
|
"""mode is a closed enum."""
|
||||||
|
|
||||||
|
workspace_id: str
|
||||||
|
page: int = Field(1, ge=1)
|
||||||
|
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||||
|
mode: AppMode | None = None
|
||||||
|
name: str | None = Field(None, max_length=200)
|
||||||
|
tag: str | None = Field(None, max_length=100)
|
||||||
|
|
||||||
|
|
||||||
|
class AppRunRequest(BaseModel):
|
||||||
|
inputs: dict[str, Any]
|
||||||
|
query: str | None = None
|
||||||
|
files: list[dict[str, Any]] | None = None
|
||||||
|
conversation_id: UUIDStrOrEmpty | None = None
|
||||||
|
auto_generate_name: bool = True
|
||||||
|
workflow_id: str | None = None
|
||||||
|
workspace_id: UUIDStrOrEmpty | None = None
|
||||||
|
|
||||||
|
@field_validator("conversation_id", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||||
|
if isinstance(value, str):
|
||||||
|
value = value.strip()
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return uuid_value(value)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceCodeRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
device_label: str
|
||||||
|
|
||||||
|
|
||||||
|
class DevicePollRequest(BaseModel):
|
||||||
|
device_code: str
|
||||||
|
client_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceLookupQuery(BaseModel):
|
||||||
|
user_code: str
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceMutateRequest(BaseModel):
|
||||||
|
user_code: str
|
||||||
|
|
||||||
|
|
||||||
|
class PermittedExternalAppsListQuery(BaseModel):
|
||||||
|
"""Strict (extra='forbid')."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
page: int = Field(1, ge=1)
|
||||||
|
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||||
|
mode: AppMode | None = None
|
||||||
|
name: str | None = Field(None, max_length=200)
|
||||||
169
api/controllers/openapi/account.py
Normal file
169
api/controllers/openapi/account.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._models import (
|
||||||
|
MAX_PAGE_LIMIT,
|
||||||
|
AccountPayload,
|
||||||
|
AccountResponse,
|
||||||
|
PaginationEnvelope,
|
||||||
|
RevokeResponse,
|
||||||
|
SessionListResponse,
|
||||||
|
SessionRow,
|
||||||
|
WorkspacePayload,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.oauth_bearer import (
|
||||||
|
ACCEPT_USER_ANY,
|
||||||
|
AuthContext,
|
||||||
|
SubjectType,
|
||||||
|
get_auth_ctx,
|
||||||
|
validate_bearer,
|
||||||
|
)
|
||||||
|
from libs.rate_limit import (
|
||||||
|
LIMIT_ME_PER_ACCOUNT,
|
||||||
|
LIMIT_ME_PER_EMAIL,
|
||||||
|
enforce,
|
||||||
|
)
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.oauth_device_flow import (
|
||||||
|
list_active_sessions,
|
||||||
|
revoke_oauth_token,
|
||||||
|
token_belongs_to_subject,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/account")
|
||||||
|
class AccountApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
def get(self):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
|
||||||
|
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||||
|
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||||
|
else:
|
||||||
|
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||||
|
|
||||||
|
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||||
|
return AccountResponse(
|
||||||
|
subject_type=ctx.subject_type,
|
||||||
|
subject_email=ctx.subject_email,
|
||||||
|
subject_issuer=ctx.subject_issuer,
|
||||||
|
account=None,
|
||||||
|
workspaces=[],
|
||||||
|
default_workspace_id=None,
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
|
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||||
|
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||||
|
default_ws_id = _pick_default_workspace(memberships)
|
||||||
|
|
||||||
|
return AccountResponse(
|
||||||
|
subject_type=ctx.subject_type,
|
||||||
|
subject_email=ctx.subject_email or (account.email if account else None),
|
||||||
|
account=_account_payload(account) if account else None,
|
||||||
|
workspaces=[_workspace_payload(m) for m in memberships],
|
||||||
|
default_workspace_id=default_ws_id,
|
||||||
|
).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/account/sessions/self")
|
||||||
|
class AccountSessionsSelfApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
def delete(self):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
_require_oauth_subject(ctx)
|
||||||
|
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||||
|
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/account/sessions")
|
||||||
|
class AccountSessionsApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
def get(self):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
page = int(request.args.get("page", "1"))
|
||||||
|
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||||
|
|
||||||
|
all_rows = list_active_sessions(db.session, ctx, now)
|
||||||
|
|
||||||
|
total = len(all_rows)
|
||||||
|
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||||
|
|
||||||
|
items = [
|
||||||
|
SessionRow(
|
||||||
|
id=str(r.id),
|
||||||
|
prefix=r.prefix,
|
||||||
|
client_id=r.client_id,
|
||||||
|
device_label=r.device_label,
|
||||||
|
created_at=_iso(r.created_at),
|
||||||
|
last_used_at=_iso(r.last_used_at),
|
||||||
|
expires_at=_iso(r.expires_at),
|
||||||
|
)
|
||||||
|
for r in sliced
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||||
|
class AccountSessionByIdApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
def delete(self, session_id: str):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
_require_oauth_subject(ctx)
|
||||||
|
|
||||||
|
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||||
|
# token IDs that belong to other subjects.
|
||||||
|
if not token_belongs_to_subject(db.session, session_id, ctx):
|
||||||
|
raise NotFound("session not found")
|
||||||
|
|
||||||
|
revoke_oauth_token(db.session, redis_client, session_id)
|
||||||
|
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||||
|
if not ctx.source.startswith("oauth"):
|
||||||
|
raise BadRequest(
|
||||||
|
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _iso(dt: datetime | None) -> str | None:
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
if dt.tzinfo is None:
|
||||||
|
dt = dt.replace(tzinfo=UTC)
|
||||||
|
return dt.isoformat().replace("+00:00", "Z")
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_default_workspace(memberships) -> str | None:
|
||||||
|
if not memberships:
|
||||||
|
return None
|
||||||
|
for join, tenant in memberships:
|
||||||
|
if getattr(join, "current", False):
|
||||||
|
return str(tenant.id)
|
||||||
|
return str(memberships[0][1].id)
|
||||||
|
|
||||||
|
|
||||||
|
def _workspace_payload(row) -> WorkspacePayload:
|
||||||
|
join, tenant = row
|
||||||
|
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||||
|
|
||||||
|
|
||||||
|
def _account_payload(account) -> AccountPayload:
|
||||||
|
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||||
165
api/controllers/openapi/app_run.py
Normal file
165
api/controllers/openapi/app_run.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable, Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._audit import emit_app_run
|
||||||
|
from controllers.openapi._models import AppRunRequest
|
||||||
|
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||||
|
from controllers.service_api.app.error import (
|
||||||
|
AppUnavailableError,
|
||||||
|
CompletionRequestError,
|
||||||
|
ConversationCompletedError,
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
)
|
||||||
|
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.errors.error import (
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from graphon.graph_engine.manager import GraphEngineManager
|
||||||
|
from graphon.model_runtime.errors.invoke import InvokeError
|
||||||
|
from libs import helper
|
||||||
|
from libs.oauth_bearer import Scope
|
||||||
|
from models.model import App, AppMode
|
||||||
|
from services.app_generate_service import AppGenerateService
|
||||||
|
from services.errors.app import (
|
||||||
|
IsDraftWorkflowError,
|
||||||
|
WorkflowIdFormatError,
|
||||||
|
WorkflowNotFoundError,
|
||||||
|
)
|
||||||
|
from services.errors.llm import InvokeRateLimitError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _translate_service_errors() -> Iterator[None]:
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except WorkflowNotFoundError as ex:
|
||||||
|
raise NotFound(str(ex))
|
||||||
|
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||||
|
raise BadRequest(str(ex))
|
||||||
|
except services.errors.conversation.ConversationNotExistsError:
|
||||||
|
raise NotFound("Conversation Not Exists.")
|
||||||
|
except services.errors.conversation.ConversationCompletedError:
|
||||||
|
raise ConversationCompletedError()
|
||||||
|
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||||
|
logger.exception("App model config broken.")
|
||||||
|
raise AppUnavailableError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except InvokeRateLimitError as ex:
|
||||||
|
raise InvokeRateLimitHttpError(ex.description)
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||||
|
return AppGenerateService.generate(
|
||||||
|
app_model=app,
|
||||||
|
user=caller,
|
||||||
|
args=args,
|
||||||
|
invoke_from=InvokeFrom.OPENAPI,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
|
||||||
|
if not payload.query or not payload.query.strip():
|
||||||
|
raise UnprocessableEntity("query_required_for_chat")
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
|
with _translate_service_errors():
|
||||||
|
return _generate(app, caller, args, streaming=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
|
||||||
|
args = payload.model_dump(exclude_none=True)
|
||||||
|
args["auto_generate_name"] = False
|
||||||
|
args.setdefault("query", "")
|
||||||
|
with _translate_service_errors():
|
||||||
|
return _generate(app, caller, args, streaming=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
|
||||||
|
if payload.query is not None:
|
||||||
|
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||||
|
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||||
|
with _translate_service_errors():
|
||||||
|
return _generate(app, caller, args, streaming=True)
|
||||||
|
|
||||||
|
|
||||||
|
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||||
|
AppMode.CHAT: _run_chat,
|
||||||
|
AppMode.AGENT_CHAT: _run_chat,
|
||||||
|
AppMode.ADVANCED_CHAT: _run_chat,
|
||||||
|
AppMode.COMPLETION: _run_completion,
|
||||||
|
AppMode.WORKFLOW: _run_workflow,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||||
|
class AppRunApi(Resource):
|
||||||
|
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||||
|
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||||
|
body = request.get_json(silent=True) or {}
|
||||||
|
try:
|
||||||
|
payload = AppRunRequest.model_validate(body)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise UnprocessableEntity(exc.json())
|
||||||
|
|
||||||
|
handler = _DISPATCH.get(app_model.mode)
|
||||||
|
if handler is None:
|
||||||
|
raise UnprocessableEntity("mode_not_runnable")
|
||||||
|
|
||||||
|
try:
|
||||||
|
stream_obj = handler(app_model, caller, payload)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("internal server error.")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
emit_app_run(
|
||||||
|
app_id=app_model.id,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
caller_kind=caller_kind,
|
||||||
|
mode=str(app_model.mode),
|
||||||
|
surface="apps",
|
||||||
|
)
|
||||||
|
|
||||||
|
return helper.compact_generate_response(stream_obj)
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||||
|
class AppRunTaskStopApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Task stopped")
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||||
|
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||||
|
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||||
|
return {"result": "success"}
|
||||||
270
api/controllers/openapi/apps.py
Normal file
270
api/controllers/openapi/apps.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
"""GET /openapi/v1/apps and per-app reads.
|
||||||
|
|
||||||
|
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||||
|
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||||
|
reads it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid as _uuid
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||||
|
|
||||||
|
from controllers.common.fields import Parameters
|
||||||
|
from controllers.common.schema import query_params_from_model
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||||
|
from controllers.openapi._models import (
|
||||||
|
AppDescribeInfo,
|
||||||
|
AppDescribeQuery,
|
||||||
|
AppDescribeResponse,
|
||||||
|
AppListQuery,
|
||||||
|
AppListResponse,
|
||||||
|
AppListRow,
|
||||||
|
TagItem,
|
||||||
|
)
|
||||||
|
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||||
|
from controllers.service_api.app.error import AppUnavailableError
|
||||||
|
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.oauth_bearer import (
|
||||||
|
ACCEPT_USER_ANY,
|
||||||
|
AuthContext,
|
||||||
|
Scope,
|
||||||
|
SubjectType,
|
||||||
|
get_auth_ctx,
|
||||||
|
require_scope,
|
||||||
|
require_workspace_member,
|
||||||
|
validate_bearer,
|
||||||
|
)
|
||||||
|
from models import App
|
||||||
|
from services.account_service import TenantService
|
||||||
|
from services.app_service import AppListParams, AppService
|
||||||
|
from services.tag_service import TagService
|
||||||
|
|
||||||
|
_APPS_READ_DECORATORS = [
|
||||||
|
require_scope(Scope.APPS_READ),
|
||||||
|
accept_subjects(SubjectType.ACCOUNT),
|
||||||
|
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||||
|
]
|
||||||
|
|
||||||
|
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||||
|
|
||||||
|
|
||||||
|
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||||
|
"opening_statement": None,
|
||||||
|
"suggested_questions": [],
|
||||||
|
"user_input_form": [],
|
||||||
|
"file_upload": None,
|
||||||
|
"system_parameters": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AppReadResource(Resource):
|
||||||
|
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||||
|
|
||||||
|
method_decorators = _APPS_READ_DECORATORS
|
||||||
|
|
||||||
|
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||||
|
ctx: AuthContext = get_auth_ctx()
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_uuid = _uuid.UUID(app_id)
|
||||||
|
is_uuid = True
|
||||||
|
except ValueError:
|
||||||
|
parsed_uuid = None
|
||||||
|
is_uuid = False
|
||||||
|
|
||||||
|
if is_uuid:
|
||||||
|
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
|
||||||
|
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||||
|
if app is None:
|
||||||
|
raise NotFound("app not found")
|
||||||
|
else:
|
||||||
|
if not workspace_id:
|
||||||
|
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||||
|
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
|
||||||
|
if len(matches) == 0:
|
||||||
|
raise NotFound("app not found")
|
||||||
|
if len(matches) > 1:
|
||||||
|
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||||
|
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||||
|
for m in matches:
|
||||||
|
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||||
|
raise Conflict("".join(lines))
|
||||||
|
app = matches[0]
|
||||||
|
|
||||||
|
require_workspace_member(ctx, str(app.tenant_id))
|
||||||
|
return app, ctx
|
||||||
|
|
||||||
|
|
||||||
|
def parameters_payload(app: App) -> dict:
|
||||||
|
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||||
|
features_dict, user_input_form = resolve_app_config(app)
|
||||||
|
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||||
|
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||||
|
class AppDescribeApi(AppReadResource):
|
||||||
|
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||||
|
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||||
|
def get(self, app_id: str):
|
||||||
|
try:
|
||||||
|
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise UnprocessableEntity(exc.json())
|
||||||
|
|
||||||
|
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||||
|
|
||||||
|
requested = query.fields
|
||||||
|
want_info = requested is None or "info" in requested
|
||||||
|
want_params = requested is None or "parameters" in requested
|
||||||
|
want_schema = requested is None or "input_schema" in requested
|
||||||
|
|
||||||
|
info = (
|
||||||
|
AppDescribeInfo(
|
||||||
|
id=str(app.id),
|
||||||
|
name=app.name,
|
||||||
|
mode=app.mode,
|
||||||
|
description=app.description,
|
||||||
|
tags=[TagItem(name=t.name) for t in app.tags],
|
||||||
|
author=app.author_name,
|
||||||
|
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||||
|
service_api_enabled=bool(app.enable_api),
|
||||||
|
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||||
|
)
|
||||||
|
if want_info
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
parameters: dict[str, Any] | None = None
|
||||||
|
input_schema: dict[str, Any] | None = None
|
||||||
|
if want_params:
|
||||||
|
try:
|
||||||
|
parameters = parameters_payload(app)
|
||||||
|
except AppUnavailableError:
|
||||||
|
parameters = dict(_EMPTY_PARAMETERS)
|
||||||
|
if want_schema:
|
||||||
|
try:
|
||||||
|
input_schema = build_input_schema(app)
|
||||||
|
except AppUnavailableError:
|
||||||
|
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||||
|
|
||||||
|
return (
|
||||||
|
AppDescribeResponse(
|
||||||
|
info=info,
|
||||||
|
parameters=parameters,
|
||||||
|
input_schema=input_schema,
|
||||||
|
).model_dump(mode="json", exclude_none=False),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps")
|
||||||
|
class AppListApi(Resource):
|
||||||
|
method_decorators = _APPS_READ_DECORATORS
|
||||||
|
|
||||||
|
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||||
|
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||||
|
def get(self):
|
||||||
|
ctx: AuthContext = get_auth_ctx()
|
||||||
|
|
||||||
|
try:
|
||||||
|
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise UnprocessableEntity(exc.json())
|
||||||
|
|
||||||
|
workspace_id = query.workspace_id
|
||||||
|
require_workspace_member(ctx, workspace_id)
|
||||||
|
|
||||||
|
empty = (
|
||||||
|
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||||
|
mode="json"
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
|
||||||
|
if query.name:
|
||||||
|
try:
|
||||||
|
parsed_uuid = _uuid.UUID(query.name)
|
||||||
|
except ValueError:
|
||||||
|
parsed_uuid = None
|
||||||
|
else:
|
||||||
|
parsed_uuid = None
|
||||||
|
|
||||||
|
tenant_name: str | None = None
|
||||||
|
if parsed_uuid is not None:
|
||||||
|
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||||
|
if app is None or str(app.tenant_id) != workspace_id:
|
||||||
|
return empty
|
||||||
|
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||||
|
item = AppListRow(
|
||||||
|
id=str(app.id),
|
||||||
|
name=app.name,
|
||||||
|
description=app.description,
|
||||||
|
mode=app.mode,
|
||||||
|
tags=[TagItem(name=t.name) for t in app.tags],
|
||||||
|
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||||
|
created_by_name=getattr(app, "author_name", None),
|
||||||
|
workspace_id=str(workspace_id),
|
||||||
|
workspace_name=tenant_name,
|
||||||
|
)
|
||||||
|
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||||
|
return env.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
tag_ids: list[str] | None = None
|
||||||
|
if query.tag:
|
||||||
|
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||||
|
if not tags:
|
||||||
|
return empty
|
||||||
|
tag_ids = [tag.id for tag in tags]
|
||||||
|
|
||||||
|
params = AppListParams(
|
||||||
|
page=query.page,
|
||||||
|
limit=query.limit,
|
||||||
|
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||||
|
name=query.name,
|
||||||
|
tag_ids=tag_ids,
|
||||||
|
status="normal",
|
||||||
|
# Visibility gate pushed into the query — pagination.total stays
|
||||||
|
# consistent across pages because invisible rows never count.
|
||||||
|
openapi_visible=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||||
|
if pagination is None:
|
||||||
|
return empty
|
||||||
|
|
||||||
|
tenant_name = None
|
||||||
|
if pagination.items:
|
||||||
|
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
AppListRow(
|
||||||
|
id=str(r.id),
|
||||||
|
name=r.name,
|
||||||
|
description=r.description,
|
||||||
|
mode=r.mode,
|
||||||
|
tags=[TagItem(name=t.name) for t in r.tags],
|
||||||
|
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||||
|
created_by_name=getattr(r, "author_name", None),
|
||||||
|
workspace_id=str(workspace_id),
|
||||||
|
workspace_name=tenant_name,
|
||||||
|
)
|
||||||
|
for r in pagination.items
|
||||||
|
]
|
||||||
|
|
||||||
|
env = AppListResponse(
|
||||||
|
page=query.page,
|
||||||
|
limit=query.limit,
|
||||||
|
total=cast(int, pagination.total),
|
||||||
|
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||||
|
data=items,
|
||||||
|
)
|
||||||
|
return env.model_dump(mode="json"), 200
|
||||||
102
api/controllers/openapi/apps_permitted_external.py
Normal file
102
api/controllers/openapi/apps_permitted_external.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||||
|
|
||||||
|
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||||
|
(public / sso_verified). License-gated: CE deploys never enable the
|
||||||
|
EE blueprint chain so this module is unreachable there.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from werkzeug.exceptions import UnprocessableEntity
|
||||||
|
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._models import (
|
||||||
|
AppListRow,
|
||||||
|
PermittedExternalAppsListQuery,
|
||||||
|
PermittedExternalAppsListResponse,
|
||||||
|
)
|
||||||
|
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.device_flow_security import enterprise_only
|
||||||
|
from libs.oauth_bearer import (
|
||||||
|
ACCEPT_USER_ANY,
|
||||||
|
Scope,
|
||||||
|
SubjectType,
|
||||||
|
require_scope,
|
||||||
|
validate_bearer,
|
||||||
|
)
|
||||||
|
from models import App
|
||||||
|
from services.account_service import TenantService
|
||||||
|
from services.app_service import AppService
|
||||||
|
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||||
|
from services.openapi.license_gate import license_required
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/permitted-external-apps")
|
||||||
|
class PermittedExternalAppsListApi(Resource):
|
||||||
|
method_decorators = [
|
||||||
|
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||||
|
license_required,
|
||||||
|
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||||
|
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||||
|
enterprise_only,
|
||||||
|
]
|
||||||
|
|
||||||
|
@openapi_ns.response(
|
||||||
|
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||||
|
)
|
||||||
|
def get(self):
|
||||||
|
try:
|
||||||
|
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise UnprocessableEntity(exc.json())
|
||||||
|
|
||||||
|
page_result = list_permitted_apps(
|
||||||
|
page=query.page,
|
||||||
|
limit=query.limit,
|
||||||
|
mode=query.mode.value if query.mode else None,
|
||||||
|
name=query.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not page_result.app_ids:
|
||||||
|
env = PermittedExternalAppsListResponse(
|
||||||
|
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||||
|
)
|
||||||
|
return env.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
apps_by_id: dict[str, App] = {
|
||||||
|
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||||
|
}
|
||||||
|
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
|
||||||
|
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
|
||||||
|
|
||||||
|
items: list[AppListRow] = []
|
||||||
|
for app_id in page_result.app_ids:
|
||||||
|
app = apps_by_id.get(app_id)
|
||||||
|
if not app or app.status != "normal":
|
||||||
|
continue
|
||||||
|
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||||
|
items.append(
|
||||||
|
AppListRow(
|
||||||
|
id=str(app.id),
|
||||||
|
name=app.name,
|
||||||
|
description=app.description,
|
||||||
|
mode=app.mode,
|
||||||
|
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||||
|
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||||
|
created_by_name=None, # cross-tenant author leak prevention
|
||||||
|
workspace_id=str(app.tenant_id),
|
||||||
|
workspace_name=tenant.name if tenant else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
env = PermittedExternalAppsListResponse(
|
||||||
|
page=query.page,
|
||||||
|
limit=query.limit,
|
||||||
|
total=page_result.total,
|
||||||
|
has_more=query.page * query.limit < page_result.total,
|
||||||
|
data=items,
|
||||||
|
)
|
||||||
|
return env.model_dump(mode="json"), 200
|
||||||
3
api/controllers/openapi/auth/__init__.py
Normal file
3
api/controllers/openapi/auth/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||||
|
|
||||||
|
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||||
46
api/controllers/openapi/auth/composition.py
Normal file
46
api/controllers/openapi/auth/composition.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||||
|
|
||||||
|
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||||
|
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||||
|
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||||
|
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from controllers.openapi.auth.pipeline import Pipeline
|
||||||
|
from controllers.openapi.auth.steps import (
|
||||||
|
AppAuthzCheck,
|
||||||
|
AppResolver,
|
||||||
|
BearerCheck,
|
||||||
|
CallerMount,
|
||||||
|
ScopeCheck,
|
||||||
|
SurfaceCheck,
|
||||||
|
WorkspaceMembershipCheck,
|
||||||
|
)
|
||||||
|
from controllers.openapi.auth.strategies import (
|
||||||
|
AccountMounter,
|
||||||
|
AclStrategy,
|
||||||
|
AppAuthzStrategy,
|
||||||
|
EndUserMounter,
|
||||||
|
MembershipStrategy,
|
||||||
|
)
|
||||||
|
from libs.oauth_bearer import SubjectType
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||||
|
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||||
|
return AclStrategy()
|
||||||
|
return MembershipStrategy()
|
||||||
|
|
||||||
|
|
||||||
|
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||||
|
BearerCheck(),
|
||||||
|
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||||
|
ScopeCheck(),
|
||||||
|
AppResolver(),
|
||||||
|
WorkspaceMembershipCheck(),
|
||||||
|
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||||
|
CallerMount(AccountMounter(), EndUserMounter()),
|
||||||
|
)
|
||||||
68
api/controllers/openapi/auth/context.py
Normal file
68
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""Mutable per-request context for the openapi auth pipeline.
|
||||||
|
|
||||||
|
Every field starts None / empty and is filled in by a step. The pipeline
|
||||||
|
is the only thing that should construct or mutate Context — handlers
|
||||||
|
read populated values via the decorator's kwargs unpacking.
|
||||||
|
|
||||||
|
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||||
|
guard extracts whatever transport-level inputs the steps need (bearer
|
||||||
|
token, path params) at the boundary and writes them into Context fields,
|
||||||
|
so steps stay testable without a request object and won't leak coupling
|
||||||
|
to a specific framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from contextvars import Token
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Literal, Protocol
|
||||||
|
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from models import App, Tenant
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Context:
|
||||||
|
required_scope: Scope
|
||||||
|
bearer_token: str | None = None
|
||||||
|
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||||
|
subject_type: SubjectType | None = None
|
||||||
|
subject_email: str | None = None
|
||||||
|
subject_issuer: str | None = None
|
||||||
|
account_id: uuid.UUID | None = None
|
||||||
|
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||||
|
token_id: uuid.UUID | None = None
|
||||||
|
token_hash: str | None = None
|
||||||
|
cached_verified_tenants: dict[str, bool] | None = None
|
||||||
|
source: str | None = None
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
app: App | None = None
|
||||||
|
tenant: Tenant | None = None
|
||||||
|
caller: object | None = None
|
||||||
|
caller_kind: Literal["account", "end_user"] | None = None
|
||||||
|
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def must_tenant(self) -> Tenant:
|
||||||
|
if not self.tenant:
|
||||||
|
raise Unauthorized("tenant is not associated")
|
||||||
|
return self.tenant
|
||||||
|
|
||||||
|
@property
|
||||||
|
def must_subject_type(self) -> SubjectType:
|
||||||
|
if not self.subject_type:
|
||||||
|
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||||
|
return self.subject_type
|
||||||
|
|
||||||
|
|
||||||
|
class Step(Protocol):
|
||||||
|
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None: ...
|
||||||
51
api/controllers/openapi/auth/pipeline.py
Normal file
51
api/controllers/openapi/auth/pipeline.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
"""Pipeline IS the auth scheme.
|
||||||
|
|
||||||
|
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||||
|
that is the design lock-in: forgetting an auth layer is structurally
|
||||||
|
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
|
||||||
|
from controllers.openapi.auth.context import Context, Step
|
||||||
|
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline:
|
||||||
|
def __init__(self, *steps: Step) -> None:
|
||||||
|
self._steps = steps
|
||||||
|
|
||||||
|
def run(self, ctx: Context) -> None:
|
||||||
|
for step in self._steps:
|
||||||
|
step(ctx)
|
||||||
|
|
||||||
|
def guard(self, *, scope: Scope):
|
||||||
|
def decorator(view):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
# Extract transport-level inputs at the boundary so steps
|
||||||
|
# stay decoupled from Flask's request object.
|
||||||
|
ctx = Context(
|
||||||
|
required_scope=scope,
|
||||||
|
bearer_token=extract_bearer(request),
|
||||||
|
path_params=dict(request.view_args or {}),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.run(ctx)
|
||||||
|
kwargs.update(
|
||||||
|
app_model=ctx.app,
|
||||||
|
caller=ctx.caller,
|
||||||
|
caller_kind=ctx.caller_kind,
|
||||||
|
)
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
if ctx.auth_ctx_reset_token is not None:
|
||||||
|
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
170
api/controllers/openapi/auth/steps.py
Normal file
170
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
"""Pipeline steps. Each is one responsibility.
|
||||||
|
|
||||||
|
`BearerCheck` is the only step that touches the token registry; downstream
|
||||||
|
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||||
|
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||||
|
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||||
|
surface gate and any handler reading the request-scoped context has a single
|
||||||
|
source of truth across both auth-attach paths. The reset token is stashed
|
||||||
|
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||||
|
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.openapi.auth.context import Context
|
||||||
|
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||||
|
from controllers.openapi.auth.surface_gate import check_surface
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.oauth_bearer import (
|
||||||
|
AuthContext,
|
||||||
|
InvalidBearerError,
|
||||||
|
Scope,
|
||||||
|
SubjectType,
|
||||||
|
check_workspace_membership,
|
||||||
|
get_authenticator,
|
||||||
|
set_auth_ctx,
|
||||||
|
)
|
||||||
|
from models import TenantStatus
|
||||||
|
from services.account_service import TenantService
|
||||||
|
from services.app_service import AppService
|
||||||
|
|
||||||
|
|
||||||
|
class BearerCheck:
|
||||||
|
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||||
|
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||||
|
Also publishes the resolved `AuthContext` via
|
||||||
|
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||||
|
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||||
|
don't see two different identity sources. The reset token is parked on
|
||||||
|
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
if not ctx.bearer_token:
|
||||||
|
raise Unauthorized("bearer required")
|
||||||
|
|
||||||
|
try:
|
||||||
|
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||||
|
except InvalidBearerError as e:
|
||||||
|
raise Unauthorized(str(e))
|
||||||
|
|
||||||
|
ctx.subject_type = authn.subject_type
|
||||||
|
ctx.subject_email = authn.subject_email
|
||||||
|
ctx.subject_issuer = authn.subject_issuer
|
||||||
|
ctx.account_id = authn.account_id
|
||||||
|
ctx.scopes = frozenset(authn.scopes)
|
||||||
|
ctx.source = authn.source
|
||||||
|
ctx.token_id = authn.token_id
|
||||||
|
ctx.expires_at = authn.expires_at
|
||||||
|
ctx.token_hash = authn.token_hash
|
||||||
|
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||||
|
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||||
|
|
||||||
|
|
||||||
|
class ScopeCheck:
|
||||||
|
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||||
|
return
|
||||||
|
raise Forbidden("insufficient_scope")
|
||||||
|
|
||||||
|
|
||||||
|
class SurfaceCheck:
|
||||||
|
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||||
|
|
||||||
|
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||||
|
self._accepted = accepted
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
check_surface(self._accepted)
|
||||||
|
|
||||||
|
|
||||||
|
class AppResolver:
|
||||||
|
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||||
|
|
||||||
|
Every endpoint using the OAuth bearer pipeline must declare
|
||||||
|
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||||
|
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||||
|
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||||
|
about the request object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
app_id = ctx.path_params.get("app_id")
|
||||||
|
if not app_id:
|
||||||
|
raise BadRequest("app_id is required in path")
|
||||||
|
app = AppService.get_app_by_id(db.session, app_id)
|
||||||
|
if not app or app.status != "normal":
|
||||||
|
raise NotFound("app not found")
|
||||||
|
if not app.enable_api:
|
||||||
|
raise Forbidden("service_api_disabled")
|
||||||
|
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||||
|
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||||
|
raise Forbidden("workspace unavailable")
|
||||||
|
ctx.app, ctx.tenant = app, tenant
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceMembershipCheck:
|
||||||
|
"""Layer 0 — workspace membership gate.
|
||||||
|
|
||||||
|
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||||
|
(dfoa_) only — SSO subjects skip.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
if dify_config.ENTERPRISE_ENABLED:
|
||||||
|
return
|
||||||
|
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||||
|
return
|
||||||
|
if ctx.account_id is None or ctx.tenant is None:
|
||||||
|
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||||
|
if ctx.token_hash is None:
|
||||||
|
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||||
|
|
||||||
|
check_workspace_membership(
|
||||||
|
account_id=ctx.account_id,
|
||||||
|
tenant_id=ctx.must_tenant.id,
|
||||||
|
token_hash=ctx.token_hash,
|
||||||
|
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AppAuthzCheck:
|
||||||
|
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||||
|
self._resolve = resolve_strategy
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
if not self._resolve().authorize(ctx):
|
||||||
|
raise Forbidden("subject_no_app_access")
|
||||||
|
|
||||||
|
|
||||||
|
class CallerMount:
|
||||||
|
def __init__(self, *mounters: CallerMounter) -> None:
|
||||||
|
self._mounters = mounters
|
||||||
|
|
||||||
|
def __call__(self, ctx: Context) -> None:
|
||||||
|
if ctx.subject_type is None:
|
||||||
|
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||||
|
for m in self._mounters:
|
||||||
|
if m.applies_to(ctx.must_subject_type):
|
||||||
|
m.mount(ctx)
|
||||||
|
return
|
||||||
|
raise Unauthorized("no caller mounter for subject type")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AppAuthzCheck",
|
||||||
|
"AppResolver",
|
||||||
|
"AuthContext",
|
||||||
|
"BearerCheck",
|
||||||
|
"CallerMount",
|
||||||
|
"ScopeCheck",
|
||||||
|
"SurfaceCheck",
|
||||||
|
"WorkspaceMembershipCheck",
|
||||||
|
]
|
||||||
168
api/controllers/openapi/auth/strategies.py
Normal file
168
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
"""Strategy classes for the openapi auth pipeline.
|
||||||
|
|
||||||
|
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||||
|
vary along independent axes; each strategy is one class so the pipeline
|
||||||
|
composition stays a flat list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
from flask_login import user_logged_in
|
||||||
|
|
||||||
|
from controllers.openapi.auth.context import Context
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.oauth_bearer import SubjectType
|
||||||
|
from services.account_service import AccountService, TenantService
|
||||||
|
from services.end_user_service import EndUserService
|
||||||
|
from services.enterprise.enterprise_service import (
|
||||||
|
EnterpriseService,
|
||||||
|
WebAppAccessMode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AppAuthzStrategy(Protocol):
|
||||||
|
def authorize(self, ctx: Context) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AclStrategy:
|
||||||
|
"""Per-app ACL, evaluated in two stages.
|
||||||
|
|
||||||
|
The EE gateway has already enforced tenancy and workspace membership
|
||||||
|
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||||
|
|
||||||
|
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||||
|
bearers belong to public-facing apps only; account bearers cover the
|
||||||
|
full set. A mismatch is an immediate deny — no IO.
|
||||||
|
2. For modes that pair with the subject, decide whether the inner
|
||||||
|
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||||
|
requires it; the remaining modes are pass-through.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||||
|
SubjectType.ACCOUNT: frozenset(
|
||||||
|
{
|
||||||
|
WebAppAccessMode.PUBLIC,
|
||||||
|
WebAppAccessMode.SSO_VERIFIED,
|
||||||
|
WebAppAccessMode.PRIVATE_ALL,
|
||||||
|
WebAppAccessMode.PRIVATE,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
SubjectType.EXTERNAL_SSO: frozenset(
|
||||||
|
{
|
||||||
|
WebAppAccessMode.PUBLIC,
|
||||||
|
WebAppAccessMode.SSO_VERIFIED,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||||
|
|
||||||
|
def authorize(self, ctx: Context) -> bool:
|
||||||
|
if ctx.app is None:
|
||||||
|
return False
|
||||||
|
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||||
|
if access_mode is None:
|
||||||
|
return False
|
||||||
|
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||||
|
return False
|
||||||
|
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||||
|
return True
|
||||||
|
return self._inner_permission_check(ctx)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||||
|
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||||
|
if settings is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return WebAppAccessMode(settings.access_mode)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||||
|
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||||
|
|
||||||
|
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||||
|
if ctx.app is None:
|
||||||
|
return False
|
||||||
|
user_id = self._resolve_user_id(ctx)
|
||||||
|
if user_id is None:
|
||||||
|
return False
|
||||||
|
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||||
|
user_id=user_id,
|
||||||
|
app_id=ctx.app.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_user_id(ctx: Context) -> str | None:
|
||||||
|
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||||
|
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||||
|
if ctx.subject_email is None:
|
||||||
|
return None
|
||||||
|
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||||
|
return str(account.id) if account is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
class MembershipStrategy:
|
||||||
|
"""Tenant-membership fallback.
|
||||||
|
|
||||||
|
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||||
|
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||||
|
denied (it requires the webapp-auth surface).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def authorize(self, ctx: Context) -> bool:
|
||||||
|
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||||
|
return False
|
||||||
|
if ctx.tenant is None:
|
||||||
|
return False
|
||||||
|
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||||
|
|
||||||
|
|
||||||
|
def _login_as(user) -> None:
|
||||||
|
"""Set Flask-Login request user so downstream services see the caller."""
|
||||||
|
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||||
|
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||||
|
|
||||||
|
|
||||||
|
class CallerMounter(Protocol):
|
||||||
|
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||||
|
|
||||||
|
def mount(self, ctx: Context) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class AccountMounter:
|
||||||
|
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||||
|
return subject_type == SubjectType.ACCOUNT
|
||||||
|
|
||||||
|
def mount(self, ctx: Context) -> None:
|
||||||
|
if ctx.account_id is None:
|
||||||
|
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||||
|
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||||
|
if account is None:
|
||||||
|
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||||
|
account.current_tenant = ctx.must_tenant
|
||||||
|
_login_as(account)
|
||||||
|
ctx.caller, ctx.caller_kind = account, "account"
|
||||||
|
|
||||||
|
|
||||||
|
class EndUserMounter:
|
||||||
|
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||||
|
return subject_type == SubjectType.EXTERNAL_SSO
|
||||||
|
|
||||||
|
def mount(self, ctx: Context) -> None:
|
||||||
|
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||||
|
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||||
|
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||||
|
InvokeFrom.OPENAPI,
|
||||||
|
tenant_id=ctx.tenant.id,
|
||||||
|
app_id=ctx.app.id,
|
||||||
|
user_id=ctx.subject_email,
|
||||||
|
)
|
||||||
|
_login_as(end_user)
|
||||||
|
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||||
89
api/controllers/openapi/auth/surface_gate.py
Normal file
89
api/controllers/openapi/auth/surface_gate.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
"""Surface gate.
|
||||||
|
|
||||||
|
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||||
|
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||||
|
audit emit + canonical-path message are single-sourced.
|
||||||
|
|
||||||
|
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||||
|
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||||
|
``openapi.wrong_surface_denied``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.openapi._audit import emit_wrong_surface
|
||||||
|
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
|
||||||
|
|
||||||
|
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||||
|
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||||
|
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||||
|
}
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., object])
|
||||||
|
|
||||||
|
|
||||||
|
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||||
|
"""Enforce that the resolved subject is in ``accepted``.
|
||||||
|
|
||||||
|
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
|
||||||
|
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
|
||||||
|
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
|
||||||
|
set the bearer layer didn't run — that's a wiring bug, not a
|
||||||
|
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||||
|
a silent 403.
|
||||||
|
"""
|
||||||
|
ctx = try_get_auth_ctx()
|
||||||
|
if ctx is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
|
||||||
|
)
|
||||||
|
|
||||||
|
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||||
|
if subject in accepted:
|
||||||
|
return
|
||||||
|
|
||||||
|
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||||
|
emit_wrong_surface(
|
||||||
|
subject_type=subject.value if subject else None,
|
||||||
|
attempted_path=request.path,
|
||||||
|
client_id=getattr(ctx, "client_id", None),
|
||||||
|
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||||
|
)
|
||||||
|
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||||
|
|
||||||
|
|
||||||
|
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||||
|
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||||
|
|
||||||
|
def deco(fn: F) -> F:
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapper(*args: object, **kwargs: object) -> object:
|
||||||
|
check_surface(accepted_set)
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper # type: ignore[return-value]
|
||||||
|
|
||||||
|
return deco
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, SubjectType):
|
||||||
|
return raw
|
||||||
|
if isinstance(raw, str):
|
||||||
|
return SubjectType(raw)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _stringify(value: object) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return str(value)
|
||||||
72
api/controllers/openapi/files.py
Normal file
72
api/controllers/openapi/files.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from flask_restx.api import HTTPStatus
|
||||||
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.common.errors import (
|
||||||
|
BlockedFileExtensionError,
|
||||||
|
FilenameNotExistsError,
|
||||||
|
FileTooLargeError,
|
||||||
|
NoFileUploadedError,
|
||||||
|
TooManyFilesError,
|
||||||
|
UnsupportedFileTypeError,
|
||||||
|
)
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.file_fields import FileResponse
|
||||||
|
from libs.oauth_bearer import Scope
|
||||||
|
from models import Account, App
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/files/upload")
|
||||||
|
class AppFileUploadApi(Resource):
|
||||||
|
@openapi_ns.doc("upload_file_for_app_input")
|
||||||
|
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
|
||||||
|
@openapi_ns.doc(
|
||||||
|
responses={
|
||||||
|
201: "File uploaded successfully",
|
||||||
|
400: "Bad request — no file or filename missing",
|
||||||
|
401: "Unauthorized — invalid or expired bearer token",
|
||||||
|
413: "File too large",
|
||||||
|
415: "Unsupported file type or blocked extension",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||||
|
if "file" not in request.files:
|
||||||
|
raise NoFileUploadedError()
|
||||||
|
if len(request.files) > 1:
|
||||||
|
raise TooManyFilesError()
|
||||||
|
|
||||||
|
file = request.files["file"]
|
||||||
|
if not file.mimetype:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
if not file.filename:
|
||||||
|
raise FilenameNotExistsError()
|
||||||
|
|
||||||
|
try:
|
||||||
|
upload_file = FileService(db.engine).upload_file(
|
||||||
|
filename=file.filename,
|
||||||
|
content=file.stream.read(),
|
||||||
|
mimetype=file.mimetype,
|
||||||
|
user=caller,
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise BadRequest(str(exc))
|
||||||
|
except services.errors.file.FileTooLargeError as exc:
|
||||||
|
raise FileTooLargeError(exc.description)
|
||||||
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
except services.errors.file.BlockedFileExtensionError as exc:
|
||||||
|
raise BlockedFileExtensionError(exc.description)
|
||||||
|
|
||||||
|
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||||
|
return response.model_dump(mode="json"), 201
|
||||||
107
api/controllers/openapi/human_input_form.py
Normal file
107
api/controllers/openapi/human_input_form.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
"""
|
||||||
|
OpenAPI bearer-authed human input form endpoints.
|
||||||
|
|
||||||
|
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
|
||||||
|
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import Response, request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from werkzeug.exceptions import BadRequest, NotFound
|
||||||
|
|
||||||
|
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||||
|
from controllers.common.schema import register_schema_models
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||||
|
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import to_timestamp
|
||||||
|
from libs.oauth_bearer import Scope
|
||||||
|
from models.model import App
|
||||||
|
from services.human_input_service import FormNotFoundError, HumanInputService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
|
||||||
|
|
||||||
|
|
||||||
|
def _jsonify_form_definition(form) -> Response:
|
||||||
|
definition_payload = form.get_definition().model_dump()
|
||||||
|
payload = {
|
||||||
|
"form_content": definition_payload["rendered_content"],
|
||||||
|
"inputs": definition_payload["inputs"],
|
||||||
|
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||||
|
"user_actions": definition_payload["user_actions"],
|
||||||
|
"expiration_time": to_timestamp(form.expiration_time),
|
||||||
|
}
|
||||||
|
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||||
|
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||||
|
raise NotFound("Form not found")
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||||
|
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||||
|
raise NotFound("Form not found")
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||||
|
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Form definition")
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||||
|
service = HumanInputService(db.engine)
|
||||||
|
form = service.get_form_by_token(form_token)
|
||||||
|
if form is None:
|
||||||
|
raise NotFound("Form not found")
|
||||||
|
|
||||||
|
_ensure_form_belongs_to_app(form, app_model)
|
||||||
|
_ensure_form_is_allowed_for_openapi(form)
|
||||||
|
service.ensure_form_active(form)
|
||||||
|
return _jsonify_form_definition(form)
|
||||||
|
|
||||||
|
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||||
|
@openapi_ns.response(200, "Form submitted")
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||||
|
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||||
|
|
||||||
|
service = HumanInputService(db.engine)
|
||||||
|
form = service.get_form_by_token(form_token)
|
||||||
|
if form is None:
|
||||||
|
raise NotFound("Form not found")
|
||||||
|
|
||||||
|
_ensure_form_belongs_to_app(form, app_model)
|
||||||
|
_ensure_form_is_allowed_for_openapi(form)
|
||||||
|
|
||||||
|
submission_user_id: str | None = None
|
||||||
|
submission_end_user_id: str | None = None
|
||||||
|
if caller_kind == "account":
|
||||||
|
submission_user_id = caller.id
|
||||||
|
else:
|
||||||
|
submission_end_user_id = caller.id
|
||||||
|
|
||||||
|
if form.recipient_type is None:
|
||||||
|
logger.warning("Recipient type is None for form, form_token=%s", form_token)
|
||||||
|
raise BadRequest("Form recipient type is invalid")
|
||||||
|
|
||||||
|
try:
|
||||||
|
service.submit_form_by_token(
|
||||||
|
recipient_type=form.recipient_type,
|
||||||
|
form_token=form_token,
|
||||||
|
selected_action_id=payload.action,
|
||||||
|
form_data=payload.inputs,
|
||||||
|
submission_user_id=submission_user_id,
|
||||||
|
submission_end_user_id=submission_end_user_id,
|
||||||
|
)
|
||||||
|
except FormNotFoundError:
|
||||||
|
raise NotFound("Form not found")
|
||||||
|
|
||||||
|
return {}, 200
|
||||||
9
api/controllers/openapi/index.py
Normal file
9
api/controllers/openapi/index.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from flask_restx import Resource
|
||||||
|
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/_health")
|
||||||
|
class HealthApi(Resource):
|
||||||
|
def get(self):
|
||||||
|
return {"ok": True}
|
||||||
398
api/controllers/openapi/oauth_device.py
Normal file
398
api/controllers/openapi/oauth_device.py
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||||
|
sub-groups in one module:
|
||||||
|
|
||||||
|
Protocol (RFC 8628, public + rate-limited):
|
||||||
|
POST /oauth/device/code
|
||||||
|
POST /oauth/device/token
|
||||||
|
GET /oauth/device/lookup
|
||||||
|
|
||||||
|
Approval (account branch, console-cookie authed):
|
||||||
|
POST /oauth/device/approve
|
||||||
|
POST /oauth/device/deny
|
||||||
|
|
||||||
|
SSO branch lives in oauth_device_sso.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_login import login_required
|
||||||
|
from flask_restx import Resource
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.common.schema import query_params_from_model
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._models import (
|
||||||
|
AccountPayload,
|
||||||
|
DeviceCodeRequest,
|
||||||
|
DeviceCodeResponse,
|
||||||
|
DeviceLookupQuery,
|
||||||
|
DeviceLookupResponse,
|
||||||
|
DeviceMutateRequest,
|
||||||
|
DeviceMutateResponse,
|
||||||
|
DevicePollRequest,
|
||||||
|
WorkspacePayload,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.helper import extract_remote_ip
|
||||||
|
from libs.login import current_account_with_tenant
|
||||||
|
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||||
|
from libs.rate_limit import (
|
||||||
|
LIMIT_APPROVE_CONSOLE,
|
||||||
|
LIMIT_DEVICE_CODE_PER_IP,
|
||||||
|
LIMIT_LOOKUP_PUBLIC,
|
||||||
|
rate_limit,
|
||||||
|
)
|
||||||
|
from services.account_service import TenantService
|
||||||
|
from services.oauth_device_flow import (
|
||||||
|
ACCOUNT_ISSUER_SENTINEL,
|
||||||
|
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||||
|
DEVICE_FLOW_TTL_SECONDS,
|
||||||
|
DeviceFlowRedis,
|
||||||
|
DeviceFlowStatus,
|
||||||
|
InvalidTransitionError,
|
||||||
|
PollPayload,
|
||||||
|
SlowDownDecision,
|
||||||
|
StateNotFoundError,
|
||||||
|
mint_oauth_token,
|
||||||
|
oauth_ttl_days,
|
||||||
|
)
|
||||||
|
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Validation helpers
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||||
|
body = request.get_json(silent=True) or {}
|
||||||
|
try:
|
||||||
|
return model.model_validate(body)
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise BadRequest(str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||||
|
try:
|
||||||
|
return model.model_validate(request.args.to_dict(flat=True))
|
||||||
|
except ValidationError as exc:
|
||||||
|
raise BadRequest(str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/oauth/device/code")
|
||||||
|
class OAuthDeviceCodeApi(Resource):
|
||||||
|
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||||
|
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||||
|
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||||
|
def post(self):
|
||||||
|
payload = _validate_json(DeviceCodeRequest)
|
||||||
|
client_id = payload.client_id
|
||||||
|
device_label = payload.device_label
|
||||||
|
|
||||||
|
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||||
|
return {"error": "unsupported_client"}, 400
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
ip = extract_remote_ip(request)
|
||||||
|
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"device_code": device_code,
|
||||||
|
"user_code": user_code,
|
||||||
|
"verification_uri": _verification_uri(),
|
||||||
|
"expires_in": expires_in,
|
||||||
|
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/oauth/device/token")
|
||||||
|
class OAuthDeviceTokenApi(Resource):
|
||||||
|
"""RFC 8628 poll."""
|
||||||
|
|
||||||
|
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||||
|
def post(self):
|
||||||
|
payload = _validate_json(DevicePollRequest)
|
||||||
|
device_code = payload.device_code
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
|
||||||
|
# slow_down beats every other branch — polling-too-fast clients
|
||||||
|
# see only that response regardless of underlying state.
|
||||||
|
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||||
|
return {"error": "slow_down"}, 400
|
||||||
|
|
||||||
|
state = store.load_by_device_code(device_code)
|
||||||
|
if state is None:
|
||||||
|
return {"error": "expired_token"}, 400
|
||||||
|
|
||||||
|
if state.status is DeviceFlowStatus.PENDING:
|
||||||
|
return {"error": "authorization_pending"}, 400
|
||||||
|
|
||||||
|
terminal = store.consume_on_poll(device_code)
|
||||||
|
if terminal is None:
|
||||||
|
return {"error": "expired_token"}, 400
|
||||||
|
|
||||||
|
if terminal.status is DeviceFlowStatus.DENIED:
|
||||||
|
return {"error": "access_denied"}, 400
|
||||||
|
|
||||||
|
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
|
||||||
|
if "token" not in poll_payload:
|
||||||
|
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||||
|
return {"error": "expired_token"}, 400
|
||||||
|
|
||||||
|
_audit_cross_ip_if_needed(state)
|
||||||
|
return poll_payload, 200
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/oauth/device/lookup")
|
||||||
|
class OAuthDeviceLookupApi(Resource):
|
||||||
|
"""Read-only — public for pre-validate before login. user_code is
|
||||||
|
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||||
|
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||||
|
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||||
|
def get(self):
|
||||||
|
payload = _validate_query(DeviceLookupQuery)
|
||||||
|
user_code = payload.user_code.strip().upper()
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
found = store.load_by_user_code(user_code)
|
||||||
|
if found is None:
|
||||||
|
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||||
|
|
||||||
|
_device_code, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": True,
|
||||||
|
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||||
|
"client_id": state.client_id,
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Approval endpoints — account branch (cookie-authed)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||||
|
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/oauth/device/approve")
|
||||||
|
class DeviceApproveApi(Resource):
|
||||||
|
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||||
|
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@bearer_feature_required
|
||||||
|
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||||
|
def post(self):
|
||||||
|
payload = _validate_json(DeviceMutateRequest)
|
||||||
|
user_code = payload.user_code.strip().upper()
|
||||||
|
|
||||||
|
account, tenant = current_account_with_tenant()
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
|
||||||
|
found = store.load_by_user_code(user_code)
|
||||||
|
if found is None:
|
||||||
|
return {"error": "expired_or_unknown"}, 404
|
||||||
|
device_code, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
return {"error": "already_resolved"}, 409
|
||||||
|
|
||||||
|
# SET NX guard — without it, two in-flight approves both pass
|
||||||
|
# PENDING, both mint, and the second upsert silently rotates the
|
||||||
|
# first caller into an already-revoked token.
|
||||||
|
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||||
|
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||||
|
return {"error": "approve_in_progress"}, 409
|
||||||
|
|
||||||
|
try:
|
||||||
|
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||||
|
try:
|
||||||
|
validate_mint_policy(
|
||||||
|
subject_type=profile.subject_type,
|
||||||
|
prefix=profile.prefix,
|
||||||
|
scopes=profile.scopes,
|
||||||
|
)
|
||||||
|
except MintPolicyViolation as e:
|
||||||
|
raise BadRequest(description=str(e)) from None
|
||||||
|
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||||
|
mint = mint_oauth_token(
|
||||||
|
db.session,
|
||||||
|
redis_client,
|
||||||
|
subject_email=account.email,
|
||||||
|
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||||
|
account_id=str(account.id),
|
||||||
|
client_id=state.client_id,
|
||||||
|
device_label=state.device_label,
|
||||||
|
prefix=profile.prefix,
|
||||||
|
ttl_days=ttl_days,
|
||||||
|
)
|
||||||
|
|
||||||
|
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||||
|
try:
|
||||||
|
store.approve(
|
||||||
|
device_code,
|
||||||
|
subject_email=account.email,
|
||||||
|
account_id=str(account.id),
|
||||||
|
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||||
|
minted_token=mint.token,
|
||||||
|
token_id=str(mint.token_id),
|
||||||
|
poll_payload=poll_payload,
|
||||||
|
)
|
||||||
|
except (StateNotFoundError, InvalidTransitionError):
|
||||||
|
# Row minted but state vanished — roll forward; the orphan
|
||||||
|
# token is revocable via auth devices list / Authorized Apps.
|
||||||
|
logger.exception("device_flow: approve raced on %s", device_code)
|
||||||
|
return {"error": "state_lost"}, 409
|
||||||
|
finally:
|
||||||
|
redis_client.delete(guard_key)
|
||||||
|
|
||||||
|
_emit_approve_audit(state, account, tenant, mint)
|
||||||
|
return {"status": "approved"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/oauth/device/deny")
|
||||||
|
class DeviceDenyApi(Resource):
|
||||||
|
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||||
|
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@bearer_feature_required
|
||||||
|
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||||
|
def post(self):
|
||||||
|
payload = _validate_json(DeviceMutateRequest)
|
||||||
|
user_code = payload.user_code.strip().upper()
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
found = store.load_by_user_code(user_code)
|
||||||
|
if found is None:
|
||||||
|
return {"error": "expired_or_unknown"}, 404
|
||||||
|
device_code, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
return {"error": "already_resolved"}, 409
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.deny(device_code)
|
||||||
|
except (StateNotFoundError, InvalidTransitionError):
|
||||||
|
logger.exception("device_flow: deny raced on %s", device_code)
|
||||||
|
return {"error": "state_lost"}, 409
|
||||||
|
|
||||||
|
_emit_deny_audit(state)
|
||||||
|
return {"status": "denied"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Helpers
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _verification_uri() -> str:
|
||||||
|
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||||
|
if base:
|
||||||
|
return f"{base.rstrip('/')}/device"
|
||||||
|
return f"{request.host_url.rstrip('/')}/device"
|
||||||
|
|
||||||
|
|
||||||
|
def _audit_cross_ip_if_needed(state) -> None:
|
||||||
|
poll_ip = extract_remote_ip(request)
|
||||||
|
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||||
|
logger.warning(
|
||||||
|
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||||
|
state.token_id,
|
||||||
|
state.created_ip,
|
||||||
|
poll_ip,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"token_id": state.token_id,
|
||||||
|
"creation_ip": state.created_ip,
|
||||||
|
"poll_ip": poll_ip,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||||
|
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
|
||||||
|
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||||
|
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||||
|
default_ws_id = None
|
||||||
|
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||||
|
default_ws_id = str(tenant)
|
||||||
|
if default_ws_id is None:
|
||||||
|
for _t, m in rows:
|
||||||
|
if getattr(m, "current", False):
|
||||||
|
default_ws_id = str(m.tenant_id)
|
||||||
|
break
|
||||||
|
if default_ws_id is None and workspaces:
|
||||||
|
default_ws_id = workspaces[0].id
|
||||||
|
|
||||||
|
payload: PollPayload = {
|
||||||
|
"token": mint.token,
|
||||||
|
"expires_at": mint.expires_at.isoformat(),
|
||||||
|
"subject_type": SubjectType.ACCOUNT,
|
||||||
|
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||||
|
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||||
|
"default_workspace_id": default_ws_id,
|
||||||
|
"token_id": str(mint.token_id),
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||||
|
logger.warning(
|
||||||
|
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||||
|
mint.token_id,
|
||||||
|
account.email,
|
||||||
|
state.client_id,
|
||||||
|
state.device_label,
|
||||||
|
mint.expires_at,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": "oauth.device_flow_approved",
|
||||||
|
"token_id": str(mint.token_id),
|
||||||
|
"subject_type": SubjectType.ACCOUNT,
|
||||||
|
"subject_email": account.email,
|
||||||
|
"account_id": str(account.id),
|
||||||
|
"tenant_id": tenant,
|
||||||
|
"client_id": state.client_id,
|
||||||
|
"device_label": state.device_label,
|
||||||
|
"scopes": ["full"],
|
||||||
|
"expires_at": mint.expires_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_deny_audit(state) -> None:
|
||||||
|
logger.warning(
|
||||||
|
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||||
|
state.client_id,
|
||||||
|
state.device_label,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": "oauth.device_flow_denied",
|
||||||
|
"client_id": state.client_id,
|
||||||
|
"device_label": state.device_label,
|
||||||
|
},
|
||||||
|
)
|
||||||
348
api/controllers/openapi/oauth_device_sso.py
Normal file
348
api/controllers/openapi/oauth_device_sso.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||||
|
EE-only. Browser flow:
|
||||||
|
|
||||||
|
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||||
|
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||||
|
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||||
|
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||||
|
|
||||||
|
Function-based (raw @bp.route) rather than Resource classes because the
|
||||||
|
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from flask import jsonify, make_response, redirect, request
|
||||||
|
from werkzeug.exceptions import (
|
||||||
|
BadGateway,
|
||||||
|
BadRequest,
|
||||||
|
Conflict,
|
||||||
|
Forbidden,
|
||||||
|
NotFound,
|
||||||
|
Unauthorized,
|
||||||
|
)
|
||||||
|
|
||||||
|
from controllers.openapi import bp
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs import jws
|
||||||
|
from libs.device_flow_security import (
|
||||||
|
APPROVAL_GRANT_COOKIE_NAME,
|
||||||
|
ApprovalGrantClaims,
|
||||||
|
approval_grant_cleared_cookie_kwargs,
|
||||||
|
approval_grant_cookie_kwargs,
|
||||||
|
consume_approval_grant_nonce,
|
||||||
|
consume_sso_assertion_nonce,
|
||||||
|
enterprise_only,
|
||||||
|
mint_approval_grant,
|
||||||
|
verify_approval_grant,
|
||||||
|
)
|
||||||
|
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||||
|
from libs.rate_limit import (
|
||||||
|
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||||
|
LIMIT_SSO_INITIATE_PER_IP,
|
||||||
|
enforce,
|
||||||
|
rate_limit,
|
||||||
|
)
|
||||||
|
from services.account_service import AccountService
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.oauth_device_flow import (
|
||||||
|
DeviceFlowRedis,
|
||||||
|
DeviceFlowStatus,
|
||||||
|
InvalidTransitionError,
|
||||||
|
PollPayload,
|
||||||
|
StateNotFoundError,
|
||||||
|
mint_oauth_token,
|
||||||
|
oauth_ttl_days,
|
||||||
|
)
|
||||||
|
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||||
|
# device_code it references.
|
||||||
|
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||||
|
|
||||||
|
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||||
|
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||||
|
@enterprise_only
|
||||||
|
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||||
|
def sso_initiate():
|
||||||
|
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||||
|
if not user_code:
|
||||||
|
raise BadRequest("user_code required")
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
found = store.load_by_user_code(user_code)
|
||||||
|
if found is None:
|
||||||
|
raise BadRequest("invalid_user_code")
|
||||||
|
_, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
raise BadRequest("invalid_user_code")
|
||||||
|
|
||||||
|
keyset = jws.KeySet.from_shared_secret()
|
||||||
|
signed_state = jws.sign(
|
||||||
|
keyset,
|
||||||
|
payload={
|
||||||
|
"redirect_url": "",
|
||||||
|
"app_code": "",
|
||||||
|
"intent": "device_flow",
|
||||||
|
"user_code": user_code,
|
||||||
|
"nonce": secrets.token_urlsafe(16),
|
||||||
|
"return_to": "",
|
||||||
|
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||||
|
},
|
||||||
|
aud=jws.AUD_STATE_ENVELOPE,
|
||||||
|
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||||
|
raise BadGateway("sso_initiate_failed") from e
|
||||||
|
|
||||||
|
url = (reply or {}).get("url")
|
||||||
|
if not url:
|
||||||
|
raise BadGateway("sso_initiate_missing_url")
|
||||||
|
|
||||||
|
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||||
|
resp = redirect(url, code=302)
|
||||||
|
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||||
|
@enterprise_only
|
||||||
|
def sso_complete():
|
||||||
|
blob = request.args.get("sso_assertion")
|
||||||
|
if not blob:
|
||||||
|
raise BadRequest("sso_assertion required")
|
||||||
|
|
||||||
|
keyset = jws.KeySet.from_shared_secret()
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||||
|
except jws.VerifyError as e:
|
||||||
|
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||||
|
raise BadRequest("invalid_sso_assertion") from e
|
||||||
|
|
||||||
|
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||||
|
raise BadRequest("invalid_sso_assertion")
|
||||||
|
|
||||||
|
user_code = (claims.get("user_code") or "").strip().upper()
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
found = store.load_by_user_code(user_code)
|
||||||
|
if found is None:
|
||||||
|
raise Conflict("user_code_not_pending")
|
||||||
|
_, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
raise Conflict("user_code_not_pending")
|
||||||
|
|
||||||
|
if AccountService.has_active_account_with_email(db.session, claims["email"]):
|
||||||
|
_emit_external_rejection_audit(
|
||||||
|
state,
|
||||||
|
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
|
||||||
|
reason="email_belongs_to_dify_account",
|
||||||
|
)
|
||||||
|
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||||
|
|
||||||
|
iss = request.host_url.rstrip("/")
|
||||||
|
cookie_value, _ = mint_approval_grant(
|
||||||
|
keyset=keyset,
|
||||||
|
iss=iss,
|
||||||
|
subject_email=claims["email"],
|
||||||
|
subject_issuer=claims["issuer"],
|
||||||
|
user_code=user_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = redirect("/device?sso_verified=1", code=302)
|
||||||
|
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||||
|
@enterprise_only
|
||||||
|
def approval_context():
|
||||||
|
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||||
|
if not token:
|
||||||
|
raise Unauthorized("no_session")
|
||||||
|
|
||||||
|
keyset = jws.KeySet.from_shared_secret()
|
||||||
|
try:
|
||||||
|
claims = verify_approval_grant(keyset, token)
|
||||||
|
except jws.VerifyError as e:
|
||||||
|
logger.warning("approval-context: bad cookie: %s", e)
|
||||||
|
raise Unauthorized("no_session") from e
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"subject_email": claims.subject_email,
|
||||||
|
"subject_issuer": claims.subject_issuer,
|
||||||
|
"user_code": claims.user_code,
|
||||||
|
"csrf_token": claims.csrf_token,
|
||||||
|
"expires_at": claims.expires_at.isoformat(),
|
||||||
|
}
|
||||||
|
), 200
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||||
|
@enterprise_only
|
||||||
|
def approve_external():
|
||||||
|
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||||
|
if not token:
|
||||||
|
raise Unauthorized("invalid_session")
|
||||||
|
|
||||||
|
keyset = jws.KeySet.from_shared_secret()
|
||||||
|
try:
|
||||||
|
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||||
|
except jws.VerifyError as e:
|
||||||
|
logger.warning("approve-external: bad cookie: %s", e)
|
||||||
|
raise Unauthorized("invalid_session") from e
|
||||||
|
|
||||||
|
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||||
|
|
||||||
|
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||||
|
if not csrf_header or csrf_header != claims.csrf_token:
|
||||||
|
raise Forbidden("csrf_mismatch")
|
||||||
|
|
||||||
|
data = request.get_json(silent=True) or {}
|
||||||
|
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||||
|
if body_user_code != claims.user_code:
|
||||||
|
raise BadRequest("user_code_mismatch")
|
||||||
|
|
||||||
|
store = DeviceFlowRedis(redis_client)
|
||||||
|
found = store.load_by_user_code(claims.user_code)
|
||||||
|
if found is None:
|
||||||
|
raise NotFound("user_code_not_pending")
|
||||||
|
device_code, state = found
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
raise Conflict("user_code_not_pending")
|
||||||
|
|
||||||
|
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
|
||||||
|
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||||
|
raise Forbidden("email_belongs_to_dify_account")
|
||||||
|
|
||||||
|
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||||
|
raise Unauthorized("session_already_consumed")
|
||||||
|
|
||||||
|
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||||
|
try:
|
||||||
|
validate_mint_policy(
|
||||||
|
subject_type=profile.subject_type,
|
||||||
|
prefix=profile.prefix,
|
||||||
|
scopes=profile.scopes,
|
||||||
|
)
|
||||||
|
except MintPolicyViolation as e:
|
||||||
|
raise BadRequest(description=str(e)) from None
|
||||||
|
|
||||||
|
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||||
|
mint = mint_oauth_token(
|
||||||
|
db.session,
|
||||||
|
redis_client,
|
||||||
|
subject_email=claims.subject_email,
|
||||||
|
subject_issuer=claims.subject_issuer,
|
||||||
|
account_id=None,
|
||||||
|
client_id=state.client_id,
|
||||||
|
device_label=state.device_label,
|
||||||
|
prefix=profile.prefix,
|
||||||
|
ttl_days=ttl_days,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SSO branch of the shared PollPayload contract: account/workspace
|
||||||
|
# fields are zero-filled (`None` / `[]`) for parity with the account
|
||||||
|
# branch in `oauth_device._build_account_poll_payload`.
|
||||||
|
poll_payload: PollPayload = {
|
||||||
|
"token": mint.token,
|
||||||
|
"expires_at": mint.expires_at.isoformat(),
|
||||||
|
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||||
|
"subject_email": claims.subject_email,
|
||||||
|
"subject_issuer": claims.subject_issuer,
|
||||||
|
"account": None,
|
||||||
|
"workspaces": [],
|
||||||
|
"default_workspace_id": None,
|
||||||
|
"token_id": str(mint.token_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
store.approve(
|
||||||
|
device_code,
|
||||||
|
subject_email=claims.subject_email,
|
||||||
|
account_id=None,
|
||||||
|
subject_issuer=claims.subject_issuer,
|
||||||
|
minted_token=mint.token,
|
||||||
|
token_id=str(mint.token_id),
|
||||||
|
poll_payload=poll_payload,
|
||||||
|
)
|
||||||
|
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||||
|
logger.exception("approve-external: state transition raced")
|
||||||
|
raise Conflict("state_lost") from e
|
||||||
|
|
||||||
|
_emit_approve_external_audit(state, claims, mint)
|
||||||
|
|
||||||
|
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||||
|
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class _RejectedClaims:
|
||||||
|
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||||
|
|
||||||
|
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||||
|
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||||
|
event without reaching for the full dataclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
subject_email: str
|
||||||
|
subject_issuer: str
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||||
|
logger.warning(
|
||||||
|
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||||
|
SubjectType.EXTERNAL_SSO,
|
||||||
|
claims.subject_email,
|
||||||
|
claims.subject_issuer,
|
||||||
|
reason,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": "oauth.device_flow_rejected",
|
||||||
|
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||||
|
"subject_email": claims.subject_email,
|
||||||
|
"subject_issuer": claims.subject_issuer,
|
||||||
|
"reason": reason,
|
||||||
|
"client_id": state.client_id,
|
||||||
|
"device_label": state.device_label,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||||
|
logger.warning(
|
||||||
|
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||||
|
SubjectType.EXTERNAL_SSO,
|
||||||
|
claims.subject_email,
|
||||||
|
claims.subject_issuer,
|
||||||
|
mint.token_id,
|
||||||
|
extra={
|
||||||
|
"audit": True,
|
||||||
|
"event": "oauth.device_flow_approved",
|
||||||
|
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||||
|
"subject_email": claims.subject_email,
|
||||||
|
"subject_issuer": claims.subject_issuer,
|
||||||
|
"token_id": str(mint.token_id),
|
||||||
|
"client_id": state.client_id,
|
||||||
|
"device_label": state.device_label,
|
||||||
|
"scopes": ["apps:run"],
|
||||||
|
"expires_at": mint.expires_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
119
api/controllers/openapi/workflow_events.py
Normal file
119
api/controllers/openapi/workflow_events.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
"""
|
||||||
|
OpenAPI bearer-authed workflow reconnect event stream endpoint.
|
||||||
|
|
||||||
|
GET /apps/<app_id>/tasks/<task_id>/events
|
||||||
|
— reconnect to the SSE stream for a paused/running workflow run.
|
||||||
|
`task_id` is treated as `workflow_run_id`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from flask import Response, request
|
||||||
|
from flask_restx import Resource
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||||
|
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||||
|
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||||
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
|
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||||
|
from core.app.apps.message_generator import MessageGenerator
|
||||||
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
|
from core.app.entities.task_entities import StreamEvent
|
||||||
|
from core.workflow.human_input_policy import HumanInputSurface
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.oauth_bearer import Scope
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
from models.model import App, AppMode
|
||||||
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
|
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||||
|
class OpenApiWorkflowEventsApi(Resource):
|
||||||
|
@openapi_ns.response(200, "SSE event stream")
|
||||||
|
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||||
|
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||||
|
app_mode = AppMode.value_of(app_model.mode)
|
||||||
|
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||||
|
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||||
|
|
||||||
|
session_maker = sessionmaker(db.engine)
|
||||||
|
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||||
|
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
run_id=task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if workflow_run is None:
|
||||||
|
raise NotFound("Workflow run not found")
|
||||||
|
|
||||||
|
if workflow_run.app_id != app_model.id:
|
||||||
|
raise NotFound("Workflow run not found")
|
||||||
|
|
||||||
|
if caller_kind == "account":
|
||||||
|
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
|
||||||
|
raise NotFound("Workflow run not found")
|
||||||
|
else:
|
||||||
|
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
|
||||||
|
raise NotFound("Workflow run not found")
|
||||||
|
|
||||||
|
workflow_run_entity = workflow_run
|
||||||
|
|
||||||
|
if workflow_run_entity.finished_at is not None:
|
||||||
|
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||||
|
task_id=workflow_run_entity.id,
|
||||||
|
workflow_run=workflow_run_entity,
|
||||||
|
creator_user=caller,
|
||||||
|
)
|
||||||
|
payload = response.model_dump(mode="json")
|
||||||
|
payload["event"] = response.event.value
|
||||||
|
|
||||||
|
def _generate_finished_events() -> Generator[str, None, None]:
|
||||||
|
yield f"data: {json.dumps(payload)}\n\n"
|
||||||
|
|
||||||
|
event_generator = _generate_finished_events
|
||||||
|
else:
|
||||||
|
msg_generator = MessageGenerator()
|
||||||
|
generator: BaseAppGenerator
|
||||||
|
if app_mode == AppMode.ADVANCED_CHAT:
|
||||||
|
generator = AdvancedChatAppGenerator()
|
||||||
|
else:
|
||||||
|
generator = WorkflowAppGenerator()
|
||||||
|
|
||||||
|
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||||
|
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||||
|
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||||
|
|
||||||
|
def _generate_stream_events():
|
||||||
|
if include_state_snapshot:
|
||||||
|
return generator.convert_to_event_stream(
|
||||||
|
build_workflow_event_stream(
|
||||||
|
app_mode=app_mode,
|
||||||
|
workflow_run=workflow_run_entity,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
session_maker=session_maker,
|
||||||
|
human_input_surface=HumanInputSurface.OPENAPI,
|
||||||
|
close_on_pause=not continue_on_pause,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return generator.convert_to_event_stream(
|
||||||
|
msg_generator.retrieve_events(
|
||||||
|
app_mode,
|
||||||
|
workflow_run_entity.id,
|
||||||
|
terminal_events=terminal_events,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
event_generator = _generate_stream_events
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
event_generator(),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
|
)
|
||||||
78
api/controllers/openapi/workspaces.py
Normal file
78
api/controllers/openapi/workspaces.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||||
|
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||||
|
|
||||||
|
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||||
|
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||||
|
matches /openapi/v1/account.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from itertools import starmap
|
||||||
|
|
||||||
|
from flask_restx import Resource
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||||
|
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.oauth_bearer import (
|
||||||
|
ACCEPT_USER_ANY,
|
||||||
|
SubjectType,
|
||||||
|
get_auth_ctx,
|
||||||
|
validate_bearer,
|
||||||
|
)
|
||||||
|
from models import Tenant, TenantAccountJoin
|
||||||
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/workspaces")
|
||||||
|
class WorkspacesApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
@accept_subjects(SubjectType.ACCOUNT)
|
||||||
|
def get(self):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
|
||||||
|
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||||
|
|
||||||
|
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||||
|
class WorkspaceByIdApi(Resource):
|
||||||
|
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||||
|
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||||
|
@accept_subjects(SubjectType.ACCOUNT)
|
||||||
|
def get(self, workspace_id: str):
|
||||||
|
ctx = get_auth_ctx()
|
||||||
|
|
||||||
|
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||||
|
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||||
|
if row is None:
|
||||||
|
raise NotFound("workspace not found")
|
||||||
|
|
||||||
|
tenant, membership = row
|
||||||
|
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||||
|
return WorkspaceSummaryResponse(
|
||||||
|
id=str(tenant.id),
|
||||||
|
name=tenant.name,
|
||||||
|
role=getattr(membership, "role", ""),
|
||||||
|
status=tenant.status,
|
||||||
|
current=getattr(membership, "current", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||||
|
return WorkspaceDetailResponse(
|
||||||
|
id=str(tenant.id),
|
||||||
|
name=tenant.name,
|
||||||
|
role=getattr(membership, "role", ""),
|
||||||
|
status=tenant.status,
|
||||||
|
current=getattr(membership, "current", False),
|
||||||
|
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||||
|
)
|
||||||
@ -16,7 +16,7 @@ from libs.passport import PassportService
|
|||||||
from libs.token import extract_webapp_passport
|
from libs.token import extract_webapp_passport
|
||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
from services.webapp_auth_service import WebAppAuthService
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
|||||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||||
if not webapp_settings:
|
if not webapp_settings:
|
||||||
raise NotFound("Web app settings not found.")
|
raise NotFound("Web app settings not found.")
|
||||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
|
||||||
|
|
||||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||||
_validate_user_accessibility(
|
_validate_user_accessibility(
|
||||||
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
|||||||
raise Unauthorized("Please re-login to access the web app.")
|
raise Unauthorized("Please re-login to access the web app.")
|
||||||
app_id = AppService.get_app_id_by_code(app_code)
|
app_id = AppService.get_app_id_by_code(app_code)
|
||||||
app_web_auth_enabled = (
|
app_web_auth_enabled = (
|
||||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
|
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
|
||||||
|
!= WebAppAccessMode.PUBLIC
|
||||||
)
|
)
|
||||||
if app_web_auth_enabled:
|
if app_web_auth_enabled:
|
||||||
raise WebAppAuthRequiredError()
|
raise WebAppAuthRequiredError()
|
||||||
|
|||||||
@ -198,7 +198,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
),
|
),
|
||||||
query=query,
|
query=query,
|
||||||
files=list(file_objs),
|
files=list(file_objs),
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=(
|
||||||
|
args.get("parent_message_id")
|
||||||
|
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||||
|
else UUID_NIL
|
||||||
|
),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|||||||
@ -167,7 +167,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
),
|
),
|
||||||
query=query,
|
query=query,
|
||||||
files=list(file_objs),
|
files=list(file_objs),
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=(
|
||||||
|
args.get("parent_message_id")
|
||||||
|
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||||
|
else UUID_NIL
|
||||||
|
),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|||||||
@ -161,7 +161,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
),
|
),
|
||||||
query=query,
|
query=query,
|
||||||
files=list(file_objs),
|
files=list(file_objs),
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=(
|
||||||
|
args.get("parent_message_id")
|
||||||
|
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||||
|
else UUID_NIL
|
||||||
|
),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
extras=extras,
|
extras=extras,
|
||||||
|
|||||||
@ -53,6 +53,14 @@ from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
|||||||
from core.trigger.trigger_manager import TriggerManager
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||||
|
|
||||||
|
# Maps the entry surface a workflow was invoked from to the HITL surface that
|
||||||
|
# its resume tokens must be filtered for. Surfaces not in this map fall back to
|
||||||
|
# the general priority ordering (typically CONSOLE > BACKSTAGE).
|
||||||
|
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
|
||||||
|
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
|
||||||
|
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
|
||||||
|
}
|
||||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -340,11 +348,7 @@ class WorkflowResponseConverter:
|
|||||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||||
human_input_form_ids,
|
human_input_form_ids,
|
||||||
session=session,
|
session=session,
|
||||||
surface=(
|
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
|
||||||
HumanInputSurface.SERVICE_API
|
|
||||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||||
|
|||||||
@ -731,6 +731,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
match invoke_from:
|
match invoke_from:
|
||||||
case InvokeFrom.SERVICE_API:
|
case InvokeFrom.SERVICE_API:
|
||||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||||
|
case InvokeFrom.OPENAPI:
|
||||||
|
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||||
case InvokeFrom.EXPLORE:
|
case InvokeFrom.EXPLORE:
|
||||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||||
case InvokeFrom.WEB_APP:
|
case InvokeFrom.WEB_APP:
|
||||||
|
|||||||
@ -24,6 +24,7 @@ class UserFrom(StrEnum):
|
|||||||
|
|
||||||
class InvokeFrom(StrEnum):
|
class InvokeFrom(StrEnum):
|
||||||
SERVICE_API = "service-api"
|
SERVICE_API = "service-api"
|
||||||
|
OPENAPI = "openapi"
|
||||||
WEB_APP = "web-app"
|
WEB_APP = "web-app"
|
||||||
TRIGGER = "trigger"
|
TRIGGER = "trigger"
|
||||||
EXPLORE = "explore"
|
EXPLORE = "explore"
|
||||||
@ -42,6 +43,7 @@ class InvokeFrom(StrEnum):
|
|||||||
InvokeFrom.EXPLORE: "explore_app",
|
InvokeFrom.EXPLORE: "explore_app",
|
||||||
InvokeFrom.TRIGGER: "trigger",
|
InvokeFrom.TRIGGER: "trigger",
|
||||||
InvokeFrom.SERVICE_API: "api",
|
InvokeFrom.SERVICE_API: "api",
|
||||||
|
InvokeFrom.OPENAPI: "openapi",
|
||||||
}
|
}
|
||||||
return source_mapping.get(self, "dev")
|
return source_mapping.get(self, "dev")
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Iterable, Sequence
|
from collections.abc import Generator, Iterable, Sequence
|
||||||
|
from threading import Lock
|
||||||
from typing import IO, Any, Literal, cast, overload, override
|
from typing import IO, Any, Literal, cast, overload, override
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -12,9 +13,9 @@ from configs import dify_config
|
|||||||
from core.llm_generator.output_parser.structured_output import (
|
from core.llm_generator.output_parser.structured_output import (
|
||||||
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
|
||||||
)
|
)
|
||||||
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.plugin.impl.asset import PluginAssetManager
|
from core.plugin.impl.asset import PluginAssetManager
|
||||||
from core.plugin.impl.model import PluginModelClient
|
from core.plugin.impl.model import PluginModelClient
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from graphon.model_runtime.entities.llm_entities import (
|
from graphon.model_runtime.entities.llm_entities import (
|
||||||
LLMResult,
|
LLMResult,
|
||||||
@ -100,36 +101,35 @@ class _PluginStructuredOutputModelInstance:
|
|||||||
|
|
||||||
|
|
||||||
class PluginModelRuntime(ModelRuntime):
|
class PluginModelRuntime(ModelRuntime):
|
||||||
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope.
|
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
|
||||||
|
|
||||||
Provider discovery goes through ``PluginService`` so the plugin lifecycle
|
|
||||||
methods and provider reads share one tenant-scoped cache owner.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
user_id: str | None
|
user_id: str | None
|
||||||
client: PluginModelClient
|
client: PluginModelClient
|
||||||
_plugin_service: type[PluginService]
|
_provider_entities: tuple[ProviderEntity, ...] | None
|
||||||
|
_provider_entities_lock: Lock
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
|
||||||
self,
|
|
||||||
tenant_id: str,
|
|
||||||
user_id: str | None,
|
|
||||||
client: PluginModelClient,
|
|
||||||
plugin_service: type[PluginService],
|
|
||||||
) -> None:
|
|
||||||
if client is None:
|
if client is None:
|
||||||
raise ValueError("client is required.")
|
raise ValueError("client is required.")
|
||||||
if plugin_service is None:
|
|
||||||
raise ValueError("plugin_service is required.")
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.client = client
|
self.client = client
|
||||||
self._plugin_service = plugin_service
|
self._provider_entities = None
|
||||||
|
self._provider_entities_lock = Lock()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||||
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client)
|
if self._provider_entities is not None:
|
||||||
|
return self._provider_entities
|
||||||
|
|
||||||
|
with self._provider_entities_lock:
|
||||||
|
if self._provider_entities is None:
|
||||||
|
self._provider_entities = tuple(
|
||||||
|
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._provider_entities
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||||
@ -628,6 +628,34 @@ class PluginModelRuntime(ModelRuntime):
|
|||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
|
||||||
|
"""
|
||||||
|
Expose a bare provider alias only for the canonical provider mapping.
|
||||||
|
|
||||||
|
Multiple plugins can publish the same short provider slug. If every
|
||||||
|
provider entity keeps that slug in ``provider_name``, callers that still
|
||||||
|
resolve by short name become order-dependent. Restrict the alias to the
|
||||||
|
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
||||||
|
remain deterministic while the runtime surface stays canonical.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
canonical_provider_id = ModelProviderID(provider.provider)
|
||||||
|
except ValueError:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if canonical_provider_id.plugin_id != provider.plugin_id:
|
||||||
|
return ""
|
||||||
|
if canonical_provider_id.provider_name != provider.provider:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
return provider.provider
|
||||||
|
|
||||||
|
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
|
||||||
|
declaration = provider.declaration.model_copy(deep=True)
|
||||||
|
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
||||||
|
declaration.provider_name = self._get_provider_short_name_alias(provider)
|
||||||
|
return declaration
|
||||||
|
|
||||||
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
def _get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||||
providers = self.fetch_model_providers()
|
providers = self.fetch_model_providers()
|
||||||
provider_entity = next((item for item in providers if item.provider == provider), None)
|
provider_entity = next((item for item in providers if item.provider == provider), None)
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from core.plugin.impl.model import PluginModelClient
|
from core.plugin.impl.model import PluginModelClient
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
||||||
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
from graphon.model_runtime.model_providers.base.ai_model import AIModel
|
||||||
@ -118,7 +117,6 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
client=PluginModelClient(),
|
client=PluginModelClient(),
|
||||||
plugin_service=PluginService,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,6 @@ from core.plugin.entities.request import (
|
|||||||
TriggerSubscriptionResponse,
|
TriggerSubscriptionResponse,
|
||||||
)
|
)
|
||||||
from core.plugin.impl.trigger import PluginTriggerClient
|
from core.plugin.impl.trigger import PluginTriggerClient
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
|
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
|
||||||
from core.trigger.entities.entities import (
|
from core.trigger.entities.entities import (
|
||||||
EventEntity,
|
EventEntity,
|
||||||
@ -31,6 +30,7 @@ from core.trigger.entities.entities import (
|
|||||||
)
|
)
|
||||||
from core.trigger.errors import TriggerProviderCredentialValidationError
|
from core.trigger.errors import TriggerProviderCredentialValidationError
|
||||||
from models.provider_ids import TriggerProviderID
|
from models.provider_ids import TriggerProviderID
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -63,7 +63,7 @@ def _get_surface_form_token(
|
|||||||
*,
|
*,
|
||||||
surface: HumanInputSurface | None,
|
surface: HumanInputSurface | None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if surface == HumanInputSurface.SERVICE_API:
|
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
|
||||||
for recipient_type, token in recipients:
|
for recipient_type, token in recipients:
|
||||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||||
return token
|
return token
|
||||||
|
|||||||
@ -11,13 +11,15 @@ from models.human_input import RecipientType
|
|||||||
class HumanInputSurface(StrEnum):
|
class HumanInputSurface(StrEnum):
|
||||||
SERVICE_API = "service_api"
|
SERVICE_API = "service_api"
|
||||||
CONSOLE = "console"
|
CONSOLE = "console"
|
||||||
|
OPENAPI = "openapi"
|
||||||
|
|
||||||
|
|
||||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
|
||||||
# should only be able to act on end-user web forms, not internal console flows.
|
# should only be able to act on end-user web forms, not internal console flows.
|
||||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||||
|
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||||
}
|
}
|
||||||
|
|
||||||
# A single HITL form can have multiple recipient records; this shared priority
|
# A single HITL form can have multiple recipient records; this shared priority
|
||||||
|
|||||||
@ -45,6 +45,7 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
|||||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||||
|
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -161,6 +162,8 @@ def create_spec_app() -> Flask:
|
|||||||
|
|
||||||
from controllers.console import bp as console_bp
|
from controllers.console import bp as console_bp
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
|
from controllers.openapi import bp as openapi_bp
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
from controllers.service_api import bp as service_api_bp
|
from controllers.service_api import bp as service_api_bp
|
||||||
from controllers.service_api import service_api_ns
|
from controllers.service_api import service_api_ns
|
||||||
from controllers.web import bp as web_bp
|
from controllers.web import bp as web_bp
|
||||||
@ -169,8 +172,9 @@ def create_spec_app() -> Flask:
|
|||||||
app.register_blueprint(console_bp)
|
app.register_blueprint(console_bp)
|
||||||
app.register_blueprint(web_bp)
|
app.register_blueprint(web_bp)
|
||||||
app.register_blueprint(service_api_bp)
|
app.register_blueprint(service_api_bp)
|
||||||
|
app.register_blueprint(openapi_bp)
|
||||||
|
|
||||||
for namespace in (console_ns, web_ns, service_api_ns):
|
for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
|
||||||
for api in namespace.apis:
|
for api in namespace.apis:
|
||||||
_materialize_inline_model_definitions(api)
|
_materialize_inline_model_definitions(api)
|
||||||
|
|
||||||
@ -201,6 +205,13 @@ def _registered_models(namespace: str) -> dict[str, object]:
|
|||||||
for api in service_api_ns.apis:
|
for api in service_api_ns.apis:
|
||||||
models.update(api.models)
|
models.update(api.models)
|
||||||
return models
|
return models
|
||||||
|
if namespace == "openapi":
|
||||||
|
from controllers.openapi import openapi_ns
|
||||||
|
|
||||||
|
models = dict(openapi_ns.models)
|
||||||
|
for api in openapi_ns.apis:
|
||||||
|
models.update(api.models)
|
||||||
|
return models
|
||||||
|
|
||||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
|
|||||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
||||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||||
|
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
|
||||||
|
OPENAPI_MAX_AGE_SECONDS: int = 600
|
||||||
|
|
||||||
|
|
||||||
def _apply_cors_once(bp, /, **cors_kwargs):
|
def _apply_cors_once(bp, /, **cors_kwargs):
|
||||||
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
|
|||||||
from controllers.files import bp as files_bp
|
from controllers.files import bp as files_bp
|
||||||
from controllers.inner_api import bp as inner_api_bp
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
from controllers.mcp import bp as mcp_bp
|
from controllers.mcp import bp as mcp_bp
|
||||||
|
from controllers.openapi import bp as openapi_bp
|
||||||
from controllers.service_api import bp as service_api_bp
|
from controllers.service_api import bp as service_api_bp
|
||||||
from controllers.trigger import bp as trigger_bp
|
from controllers.trigger import bp as trigger_bp
|
||||||
from controllers.web import bp as web_bp
|
from controllers.web import bp as web_bp
|
||||||
@ -41,6 +44,23 @@ def init_app(app: DifyApp):
|
|||||||
)
|
)
|
||||||
app.register_blueprint(service_api_bp)
|
app.register_blueprint(service_api_bp)
|
||||||
|
|
||||||
|
if dify_config.OPENAPI_ENABLED:
|
||||||
|
# User-scoped programmatic API. Default empty allowlist = same-origin
|
||||||
|
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
|
||||||
|
# integrations. supports_credentials so cookie-authed approve/deny
|
||||||
|
# work; cross-origin OPTIONS without an allowed origin will fail
|
||||||
|
# the same as on the console blueprint.
|
||||||
|
_apply_cors_once(
|
||||||
|
openapi_bp,
|
||||||
|
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=list(OPENAPI_HEADERS),
|
||||||
|
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||||
|
expose_headers=list(EXPOSED_HEADERS),
|
||||||
|
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||||
|
)
|
||||||
|
app.register_blueprint(openapi_bp)
|
||||||
|
|
||||||
_apply_cors_once(
|
_apply_cors_once(
|
||||||
web_bp,
|
web_bp,
|
||||||
resources={
|
resources={
|
||||||
|
|||||||
@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||||
"schedule": crontab(minute="0", hour="0"),
|
"schedule": crontab(minute="0", hour="0"),
|
||||||
}
|
}
|
||||||
|
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||||
|
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||||
|
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||||
|
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||||
|
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||||
|
}
|
||||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||||
imports.append("schedule.workflow_schedule_task")
|
imports.append("schedule.workflow_schedule_task")
|
||||||
beat_schedule["workflow_schedule_task"] = {
|
beat_schedule["workflow_schedule_task"] = {
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
|
|||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.passport import PassportService
|
from libs.passport import PassportService
|
||||||
from libs.token import extract_access_token, extract_webapp_passport
|
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
|
||||||
from models import Account, Tenant, TenantAccountJoin
|
from models import Account, Tenant, TenantAccountJoin
|
||||||
from models.model import AppMCPServer, EndUser
|
from models.model import AppMCPServer, EndUser
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
|||||||
|
|
||||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
return logged_in_account
|
return logged_in_account
|
||||||
|
elif request.blueprint == "openapi":
|
||||||
|
# Account-branch device-flow approval routes (approve / deny /
|
||||||
|
# approval-context) sit under @login_required and authenticate via
|
||||||
|
# the console session cookie. Cookie-only on purpose — bearer
|
||||||
|
# tokens (dfoa_/dfoe_) live on the Authorization header and are
|
||||||
|
# validated by AppPipeline, not flask-login.
|
||||||
|
cookie_token = extract_console_cookie_token(request)
|
||||||
|
if not cookie_token:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
decoded = PassportService().verify(cookie_token)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
source = decoded.get("token_source")
|
||||||
|
if source or not user_id:
|
||||||
|
return None
|
||||||
|
return AccountService.load_logged_in_account(account_id=user_id)
|
||||||
elif request.blueprint == "web":
|
elif request.blueprint == "web":
|
||||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||||
|
|||||||
23
api/extensions/ext_oauth_bearer.py
Normal file
23
api/extensions/ext_oauth_bearer.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||||
|
and ext_redis (needs both factories).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from dify_app import DifyApp
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.oauth_bearer import build_and_bind
|
||||||
|
|
||||||
|
|
||||||
|
def is_enabled() -> bool:
|
||||||
|
return dify_config.ENABLE_OAUTH_BEARER
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(app: DifyApp) -> None:
|
||||||
|
# scoped_session isn't a context manager; request teardown closes it.
|
||||||
|
def session_factory():
|
||||||
|
return db.session
|
||||||
|
|
||||||
|
build_and_bind(session_factory=session_factory, redis_client=redis_client)
|
||||||
196
api/libs/device_flow_security.py
Normal file
196
api/libs/device_flow_security.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||||
|
cookie mint/verify/consume, and anti-framing headers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask import Blueprint
|
||||||
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
|
from libs import jws
|
||||||
|
from libs.token import is_secure
|
||||||
|
from services.feature_service import FeatureService, LicenseStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# enterprise_only decorator
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
|
||||||
|
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
|
||||||
|
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
|
||||||
|
|
||||||
|
|
||||||
|
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||||
|
responses don't consume the bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
settings = FeatureService.get_system_features()
|
||||||
|
if settings.license.status not in _EE_ENABLED_STATUSES:
|
||||||
|
raise NotFound()
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# approval_grant cookie
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
|
||||||
|
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
|
||||||
|
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
|
||||||
|
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
|
||||||
|
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
|
||||||
|
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ApprovalGrantClaims:
|
||||||
|
subject_email: str
|
||||||
|
subject_issuer: str
|
||||||
|
user_code: str
|
||||||
|
nonce: str
|
||||||
|
csrf_token: str
|
||||||
|
expires_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
def mint_approval_grant(
|
||||||
|
*,
|
||||||
|
keyset: jws.KeySet,
|
||||||
|
iss: str,
|
||||||
|
subject_email: str,
|
||||||
|
subject_issuer: str,
|
||||||
|
user_code: str,
|
||||||
|
) -> tuple[str, ApprovalGrantClaims]:
|
||||||
|
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
|
||||||
|
source of truth for Path/HttpOnly/Secure/SameSite.
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||||
|
nonce = _random_opaque()
|
||||||
|
csrf_token = _random_opaque()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"iss": iss,
|
||||||
|
"subject_email": subject_email,
|
||||||
|
"subject_issuer": subject_issuer,
|
||||||
|
"user_code": user_code,
|
||||||
|
"nonce": nonce,
|
||||||
|
"csrf_token": csrf_token,
|
||||||
|
}
|
||||||
|
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||||
|
|
||||||
|
return token, ApprovalGrantClaims(
|
||||||
|
subject_email=subject_email,
|
||||||
|
subject_issuer=subject_issuer,
|
||||||
|
user_code=user_code,
|
||||||
|
nonce=nonce,
|
||||||
|
csrf_token=csrf_token,
|
||||||
|
expires_at=exp,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
|
||||||
|
"""Sig + aud + exp only — nonce consumption is the caller's job."""
|
||||||
|
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||||
|
return ApprovalGrantClaims(
|
||||||
|
subject_email=data["subject_email"],
|
||||||
|
subject_issuer=data["subject_issuer"],
|
||||||
|
user_code=data["user_code"],
|
||||||
|
nonce=data["nonce"],
|
||||||
|
csrf_token=data["csrf_token"],
|
||||||
|
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
|
||||||
|
if not nonce:
|
||||||
|
return False
|
||||||
|
return bool(
|
||||||
|
redis_client.set(
|
||||||
|
NONCE_KEY_FMT.format(nonce=nonce),
|
||||||
|
"1",
|
||||||
|
nx=True,
|
||||||
|
ex=NONCE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
|
||||||
|
if not nonce:
|
||||||
|
return False
|
||||||
|
return bool(
|
||||||
|
redis_client.set(
|
||||||
|
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
|
||||||
|
"1",
|
||||||
|
nx=True,
|
||||||
|
ex=NONCE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def approval_grant_cookie_kwargs(value: str) -> dict:
|
||||||
|
"""``secure`` follows is_secure() so HTTP-only deployments don't
|
||||||
|
silently drop the cookie.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||||
|
"value": value,
|
||||||
|
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
|
||||||
|
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||||
|
"secure": is_secure(),
|
||||||
|
"httponly": True,
|
||||||
|
"samesite": "Lax",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def approval_grant_cleared_cookie_kwargs() -> dict:
|
||||||
|
return {
|
||||||
|
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||||
|
"value": "",
|
||||||
|
"max_age": 0,
|
||||||
|
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||||
|
"secure": is_secure(),
|
||||||
|
"httponly": True,
|
||||||
|
"samesite": "Lax",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _random_opaque() -> str:
|
||||||
|
return secrets.token_urlsafe(16)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Anti-framing headers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
_ANTI_FRAMING_HEADERS = {
|
||||||
|
"X-Frame-Options": "DENY",
|
||||||
|
"Content-Security-Policy": "frame-ancestors 'none'",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def attach_anti_framing(bp: Blueprint) -> None:
|
||||||
|
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
||||||
|
|
||||||
|
@bp.after_request
|
||||||
|
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
||||||
|
for name, value in _ANTI_FRAMING_HEADERS.items():
|
||||||
|
response.headers.setdefault(name, value)
|
||||||
|
return response
|
||||||
@ -76,6 +76,7 @@ def register_external_error_handlers(api: Api):
|
|||||||
|
|
||||||
def handle_value_error(e: ValueError):
|
def handle_value_error(e: ValueError):
|
||||||
got_request_exception.send(current_app, exception=e)
|
got_request_exception.send(current_app, exception=e)
|
||||||
|
current_app.logger.exception("value_error in request handler")
|
||||||
status_code = 400
|
status_code = 400
|
||||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||||
return data, status_code
|
return data, status_code
|
||||||
|
|||||||
@ -595,3 +595,18 @@ class RateLimiter:
|
|||||||
|
|
||||||
self._redis_client.zadd(key, {member: current_time})
|
self._redis_client.zadd(key, {member: current_time})
|
||||||
self._redis_client.expire(key, self.time_window * 2)
|
self._redis_client.expire(key, self.time_window * 2)
|
||||||
|
|
||||||
|
def seconds_until_available(self, email: str) -> int:
|
||||||
|
"""Seconds until the oldest in-window entry expires, freeing a slot.
|
||||||
|
|
||||||
|
Defensive floor of 1 second. Caller should only invoke this after
|
||||||
|
is_rate_limited() returned True.
|
||||||
|
"""
|
||||||
|
key = self._get_key(email)
|
||||||
|
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
|
||||||
|
if not oldest:
|
||||||
|
return 1
|
||||||
|
_member, score = oldest[0]
|
||||||
|
free_at = int(score) + self.time_window
|
||||||
|
remaining = free_at - int(time.time())
|
||||||
|
return max(remaining, 1)
|
||||||
|
|||||||
108
api/libs/jws.py
Normal file
108
api/libs/jws.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
||||||
|
state envelope, external subject assertion, and approval-grant cookie —
|
||||||
|
all three share one key-set so api ↔ enterprise can verify each other.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
|
||||||
|
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
|
||||||
|
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
|
||||||
|
|
||||||
|
ACTIVE_KID_V1 = "dify-shared-v1"
|
||||||
|
|
||||||
|
|
||||||
|
class KeySetError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KeySet:
|
||||||
|
"""``from_entries`` reserves multi-kid construction for rotation slots."""
|
||||||
|
|
||||||
|
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
|
||||||
|
if active_kid not in entries:
|
||||||
|
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
|
||||||
|
if not entries[active_kid]:
|
||||||
|
raise KeySetError(f"active kid {active_kid!r} has empty secret")
|
||||||
|
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
|
||||||
|
self._active_kid = active_kid
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_shared_secret(cls) -> KeySet:
|
||||||
|
secret = dify_config.SECRET_KEY
|
||||||
|
if not secret:
|
||||||
|
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
|
||||||
|
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
|
||||||
|
return cls(entries, active_kid)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_kid(self) -> str:
|
||||||
|
return self._active_kid
|
||||||
|
|
||||||
|
def lookup(self, kid: str) -> bytes | None:
|
||||||
|
return self._entries.get(kid)
|
||||||
|
|
||||||
|
|
||||||
|
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
|
||||||
|
"""``iat`` + ``exp`` are injected here; callers must not set them."""
|
||||||
|
if "aud" in payload or "iat" in payload or "exp" in payload:
|
||||||
|
raise ValueError("reserved claim present in payload (aud/iat/exp)")
|
||||||
|
if ttl_seconds <= 0:
|
||||||
|
raise ValueError("ttl_seconds must be positive")
|
||||||
|
|
||||||
|
kid = keyset.active_kid
|
||||||
|
secret = keyset.lookup(kid)
|
||||||
|
if secret is None:
|
||||||
|
raise KeySetError(f"active kid {kid!r} lookup miss")
|
||||||
|
|
||||||
|
iat = datetime.now(UTC)
|
||||||
|
exp = iat + timedelta(seconds=ttl_seconds)
|
||||||
|
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
|
||||||
|
return jwt.encode(
|
||||||
|
claims,
|
||||||
|
secret,
|
||||||
|
algorithm="HS256",
|
||||||
|
headers={"kid": kid, "typ": "JWT"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VerifyError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
|
||||||
|
"""Unknown kid is rejected — never fall back to the active kid, since
|
||||||
|
a past kid value would otherwise be forgeable by anyone who saw it.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
header = jwt.get_unverified_header(token)
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
raise VerifyError(f"decode header: {e}") from e
|
||||||
|
kid = header.get("kid")
|
||||||
|
if not kid:
|
||||||
|
raise VerifyError("no kid in header")
|
||||||
|
secret = keyset.lookup(kid)
|
||||||
|
if secret is None:
|
||||||
|
raise VerifyError(f"unknown kid {kid!r}")
|
||||||
|
try:
|
||||||
|
return jwt.decode(
|
||||||
|
token,
|
||||||
|
secret,
|
||||||
|
algorithms=["HS256"],
|
||||||
|
audience=expected_aud,
|
||||||
|
)
|
||||||
|
except jwt.ExpiredSignatureError as e:
|
||||||
|
raise VerifyError("token expired") from e
|
||||||
|
except jwt.InvalidAudienceError as e:
|
||||||
|
raise VerifyError("aud mismatch") from e
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
raise VerifyError(f"decode: {e}") from e
|
||||||
685
api/libs/oauth_bearer.py
Normal file
685
api/libs/oauth_bearer.py
Normal file
@ -0,0 +1,685 @@
|
|||||||
|
"""OAuth bearer primitives.
|
||||||
|
|
||||||
|
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||||
|
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||||
|
Authenticator + validate_bearer stay untouched.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from contextvars import ContextVar, Token
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Literal, ParamSpec, Protocol, TypeVar
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from libs.rate_limit import enforce_bearer_rate_limit
|
||||||
|
from models import Account, OAuthAccessToken, TenantAccountJoin
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Contract — types, enums, protocols
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class SubjectType(StrEnum):
|
||||||
|
ACCOUNT = "account"
|
||||||
|
EXTERNAL_SSO = "external_sso"
|
||||||
|
|
||||||
|
|
||||||
|
class Scope(StrEnum):
|
||||||
|
"""Catalog of bearer scopes recognised by the openapi surface.
|
||||||
|
|
||||||
|
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
|
||||||
|
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
|
||||||
|
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
FULL = "full"
|
||||||
|
APPS_READ = "apps:read"
|
||||||
|
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
||||||
|
APPS_RUN = "apps:run"
|
||||||
|
|
||||||
|
|
||||||
|
class Accepts(StrEnum):
|
||||||
|
"""Subject types a route is willing to accept as caller."""
|
||||||
|
|
||||||
|
USER_ACCOUNT = "user_account"
|
||||||
|
USER_EXT_SSO = "user_ext_sso"
|
||||||
|
|
||||||
|
|
||||||
|
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
|
||||||
|
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
|
||||||
|
|
||||||
|
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
|
||||||
|
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
|
||||||
|
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class AuthContext:
|
||||||
|
"""Per-request identity published via :data:`_auth_ctx_var`
|
||||||
|
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
|
||||||
|
``subject_type`` / ``source`` come from the TokenKind, not the DB —
|
||||||
|
corrupt rows can't elevate scope.
|
||||||
|
|
||||||
|
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
||||||
|
authenticate time. Per-request mutations write through to Redis via
|
||||||
|
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
|
||||||
|
"""
|
||||||
|
|
||||||
|
subject_type: SubjectType
|
||||||
|
subject_email: str | None
|
||||||
|
subject_issuer: str | None
|
||||||
|
account_id: uuid.UUID | None
|
||||||
|
client_id: str | None
|
||||||
|
scopes: frozenset[Scope]
|
||||||
|
token_id: uuid.UUID
|
||||||
|
source: str
|
||||||
|
expires_at: datetime | None
|
||||||
|
token_hash: str
|
||||||
|
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
_auth_ctx_var: ContextVar[AuthContext] = ContextVar("openapi_auth_ctx")
|
||||||
|
|
||||||
|
|
||||||
|
def set_auth_ctx(ctx: AuthContext) -> Token[AuthContext]:
|
||||||
|
return _auth_ctx_var.set(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_auth_ctx(token: Token[AuthContext]) -> None:
|
||||||
|
_auth_ctx_var.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_ctx() -> AuthContext:
|
||||||
|
return _auth_ctx_var.get()
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_auth_ctx() -> AuthContext | None:
|
||||||
|
return _auth_ctx_var.get(None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class ResolvedRow:
|
||||||
|
subject_email: str | None
|
||||||
|
subject_issuer: str | None
|
||||||
|
account_id: uuid.UUID | None
|
||||||
|
client_id: str | None
|
||||||
|
token_id: uuid.UUID
|
||||||
|
expires_at: datetime | None
|
||||||
|
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_cache(self) -> dict:
|
||||||
|
return {
|
||||||
|
"subject_email": self.subject_email,
|
||||||
|
"subject_issuer": self.subject_issuer,
|
||||||
|
"account_id": str(self.account_id) if self.account_id else None,
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"token_id": str(self.token_id),
|
||||||
|
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||||
|
"verified_tenants": dict(self.verified_tenants),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cache(cls, data: dict) -> ResolvedRow:
|
||||||
|
return cls(
|
||||||
|
subject_email=data["subject_email"],
|
||||||
|
subject_issuer=data["subject_issuer"],
|
||||||
|
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
|
||||||
|
client_id=data.get("client_id"),
|
||||||
|
token_id=uuid.UUID(data["token_id"]),
|
||||||
|
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
|
||||||
|
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
|
||||||
|
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
|
||||||
|
|
||||||
|
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
|
||||||
|
on all live deployments (60s TTL → safe to drop one release after rollout).
|
||||||
|
"""
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
return {}
|
||||||
|
out: dict[str, bool] = {}
|
||||||
|
for k, v in raw.items():
|
||||||
|
if isinstance(v, bool):
|
||||||
|
out[k] = v
|
||||||
|
elif v == "ok":
|
||||||
|
out[k] = True
|
||||||
|
elif v == "denied":
|
||||||
|
out[k] = False
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Resolver(Protocol):
|
||||||
|
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class TokenKind:
|
||||||
|
prefix: str
|
||||||
|
subject_type: SubjectType
|
||||||
|
scopes: frozenset[Scope]
|
||||||
|
source: str
|
||||||
|
resolver: Resolver
|
||||||
|
|
||||||
|
def matches(self, token: str) -> bool:
|
||||||
|
return token.startswith(self.prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MintProfile:
|
||||||
|
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
|
||||||
|
|
||||||
|
Consumers:
|
||||||
|
- ``build_registry`` reads scopes here so the resolve-time TokenKind
|
||||||
|
cannot drift from the mint-time intent.
|
||||||
|
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
|
||||||
|
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
|
||||||
|
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
|
||||||
|
the (subject_type, prefix, scopes) triple a caller intends to mint
|
||||||
|
against this table — a caller that assembles its own scope set
|
||||||
|
from a non-canonical source will fail closed at approve time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
subject_type: SubjectType
|
||||||
|
prefix: str
|
||||||
|
scopes: frozenset[Scope]
|
||||||
|
|
||||||
|
|
||||||
|
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
|
||||||
|
SubjectType.ACCOUNT: MintProfile(
|
||||||
|
subject_type=SubjectType.ACCOUNT,
|
||||||
|
prefix="dfoa_",
|
||||||
|
scopes=frozenset({Scope.FULL}),
|
||||||
|
),
|
||||||
|
SubjectType.EXTERNAL_SSO: MintProfile(
|
||||||
|
subject_type=SubjectType.EXTERNAL_SSO,
|
||||||
|
prefix="dfoe_",
|
||||||
|
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidBearerError(Exception):
|
||||||
|
"""Token missing, unknown prefix, or no live row."""
|
||||||
|
|
||||||
|
|
||||||
|
class TokenExpiredError(Exception):
|
||||||
|
"""Hard-expire bookkeeping is the resolver's job before raising."""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Registry
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TokenKindRegistry:
|
||||||
|
def __init__(self, kinds: Iterable[TokenKind]) -> None:
|
||||||
|
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
|
||||||
|
prefixes = [k.prefix for k in self._kinds]
|
||||||
|
if len(set(prefixes)) != len(prefixes):
|
||||||
|
raise ValueError(f"duplicate prefix in registry: {prefixes}")
|
||||||
|
|
||||||
|
def find(self, token: str) -> TokenKind | None:
|
||||||
|
for k in self._kinds:
|
||||||
|
if k.matches(token):
|
||||||
|
return k
|
||||||
|
return None
|
||||||
|
|
||||||
|
def kinds(self) -> tuple[TokenKind, ...]:
|
||||||
|
return self._kinds
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Authenticator
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_hex(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class BearerAuthenticator:
|
||||||
|
def __init__(self, registry: TokenKindRegistry) -> None:
|
||||||
|
self._registry = registry
|
||||||
|
|
||||||
|
@property
|
||||||
|
def registry(self) -> TokenKindRegistry:
|
||||||
|
return self._registry
|
||||||
|
|
||||||
|
def authenticate(self, token: str) -> AuthContext:
|
||||||
|
"""Identity + per-token rate limit (single source).
|
||||||
|
|
||||||
|
Both the openapi pipeline (`BearerCheck`) and the decorator
|
||||||
|
(`validate_bearer`) call this — rate-limit fires exactly once per
|
||||||
|
request regardless of which path hosts the route.
|
||||||
|
"""
|
||||||
|
kind = self._registry.find(token)
|
||||||
|
if kind is None:
|
||||||
|
raise InvalidBearerError("unknown token prefix")
|
||||||
|
token_hash = sha256_hex(token)
|
||||||
|
row = kind.resolver.resolve(token_hash)
|
||||||
|
if row is None:
|
||||||
|
raise InvalidBearerError("token unknown or revoked")
|
||||||
|
enforce_bearer_rate_limit(token_hash)
|
||||||
|
return AuthContext(
|
||||||
|
subject_type=kind.subject_type,
|
||||||
|
subject_email=row.subject_email,
|
||||||
|
subject_issuer=row.subject_issuer,
|
||||||
|
account_id=row.account_id,
|
||||||
|
client_id=row.client_id,
|
||||||
|
scopes=kind.scopes,
|
||||||
|
token_id=row.token_id,
|
||||||
|
source=kind.source,
|
||||||
|
expires_at=row.expires_at,
|
||||||
|
token_hash=token_hash,
|
||||||
|
verified_tenants=dict(row.verified_tenants),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# OAuth access token resolver (PAT resolver would be a sibling class)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
|
||||||
|
POSITIVE_TTL_SECONDS = 60
|
||||||
|
NEGATIVE_TTL_SECONDS = 10
|
||||||
|
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
|
||||||
|
|
||||||
|
ScopeVariant = Literal["account", "external_sso"]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccessTokenResolver:
|
||||||
|
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
|
||||||
|
sharing DB + cache plumbing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory,
|
||||||
|
redis_client,
|
||||||
|
positive_ttl: int = POSITIVE_TTL_SECONDS,
|
||||||
|
negative_ttl: int = NEGATIVE_TTL_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
self.session_factory = session_factory
|
||||||
|
self._redis = redis_client
|
||||||
|
self._positive_ttl = positive_ttl
|
||||||
|
self._negative_ttl = negative_ttl
|
||||||
|
|
||||||
|
def for_account(self) -> Resolver:
|
||||||
|
return _VariantResolver(self, variant="account")
|
||||||
|
|
||||||
|
def for_external_sso(self) -> Resolver:
|
||||||
|
return _VariantResolver(self, variant="external_sso")
|
||||||
|
|
||||||
|
def _cache_key(self, token_hash: str) -> str:
|
||||||
|
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||||
|
|
||||||
|
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
|
||||||
|
raw = self._redis.get(self._cache_key(token_hash))
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||||
|
if text == "invalid":
|
||||||
|
return "invalid"
|
||||||
|
try:
|
||||||
|
return ResolvedRow.from_cache(json.loads(text))
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
logger.warning("auth:token cache entry malformed; treating as miss")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
|
||||||
|
self._redis.setex(
|
||||||
|
self._cache_key(token_hash),
|
||||||
|
self._positive_ttl,
|
||||||
|
json.dumps(row.to_cache()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def cache_set_negative(self, token_hash: str) -> None:
|
||||||
|
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
|
||||||
|
|
||||||
|
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
|
||||||
|
"""Atomic CAS — only the worker that flips revoked_at emits audit;
|
||||||
|
replays are idempotent.
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
update(OAuthAccessToken)
|
||||||
|
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
|
||||||
|
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||||
|
)
|
||||||
|
result = session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
if result.rowcount == 1: # type: ignore
|
||||||
|
logger.warning(
|
||||||
|
"audit: %s token_id=%s",
|
||||||
|
AUDIT_OAUTH_EXPIRED,
|
||||||
|
row_id,
|
||||||
|
extra={"audit": True, "token_id": str(row_id)},
|
||||||
|
)
|
||||||
|
self._redis.delete(self._cache_key(token_hash))
|
||||||
|
self.cache_set_negative(token_hash)
|
||||||
|
|
||||||
|
|
||||||
|
class _VariantResolver:
|
||||||
|
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
|
||||||
|
self._parent = parent
|
||||||
|
self._variant = variant
|
||||||
|
|
||||||
|
def resolve(self, token_hash: str) -> ResolvedRow | None:
|
||||||
|
cached = self._parent.cache_get(token_hash)
|
||||||
|
if cached == "invalid":
|
||||||
|
return None
|
||||||
|
if cached is not None and not isinstance(cached, str):
|
||||||
|
if not self._matches_variant(cached):
|
||||||
|
return None
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Flask-SQLAlchemy's scoped_session is request-bound and not a
|
||||||
|
# context manager; use it directly.
|
||||||
|
session = self._parent.session_factory()
|
||||||
|
row = self._load_from_db(session, token_hash)
|
||||||
|
if row is None:
|
||||||
|
self._parent.cache_set_negative(token_hash)
|
||||||
|
return None
|
||||||
|
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
if row.expires_at is not None and row.expires_at <= now:
|
||||||
|
self._parent.hard_expire(session, row.id, token_hash)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self._matches_variant_model(row):
|
||||||
|
logger.error(
|
||||||
|
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
|
||||||
|
row.id,
|
||||||
|
row.prefix,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
resolved = ResolvedRow(
|
||||||
|
subject_email=row.subject_email,
|
||||||
|
subject_issuer=row.subject_issuer,
|
||||||
|
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
|
||||||
|
client_id=row.client_id,
|
||||||
|
token_id=uuid.UUID(str(row.id)),
|
||||||
|
expires_at=row.expires_at,
|
||||||
|
)
|
||||||
|
self._parent.cache_set_positive(token_hash, resolved)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
def _matches_variant(self, row: ResolvedRow) -> bool:
|
||||||
|
has_account = row.account_id is not None
|
||||||
|
if self._variant == "account":
|
||||||
|
return has_account
|
||||||
|
return not has_account
|
||||||
|
|
||||||
|
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
|
||||||
|
has_account = row.account_id is not None
|
||||||
|
if self._variant == "account":
|
||||||
|
return has_account and row.prefix == "dfoa_"
|
||||||
|
return (not has_account) and row.prefix == "dfoe_"
|
||||||
|
|
||||||
|
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
|
||||||
|
return (
|
||||||
|
session.query(OAuthAccessToken)
|
||||||
|
.filter(
|
||||||
|
OAuthAccessToken.token_hash == token_hash,
|
||||||
|
OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Layer 0 — workspace membership cache + helper
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
|
||||||
|
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
|
||||||
|
`auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
|
||||||
|
rebuilds via authenticate() and re-runs Layer 0.
|
||||||
|
"""
|
||||||
|
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||||
|
raw = redis_client.get(cache_key)
|
||||||
|
if raw is None:
|
||||||
|
return
|
||||||
|
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||||
|
if text == "invalid":
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
data = json.loads(text)
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
return
|
||||||
|
ttl = redis_client.ttl(cache_key)
|
||||||
|
if ttl <= 0:
|
||||||
|
return
|
||||||
|
data.setdefault("verified_tenants", {})[tenant_id] = verdict
|
||||||
|
redis_client.setex(cache_key, ttl, json.dumps(data))
|
||||||
|
|
||||||
|
|
||||||
|
def check_workspace_membership(
|
||||||
|
*,
|
||||||
|
account_id: uuid.UUID | str,
|
||||||
|
tenant_id: str,
|
||||||
|
token_hash: str,
|
||||||
|
cached_verdicts: dict[str, bool],
|
||||||
|
) -> None:
|
||||||
|
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
|
||||||
|
|
||||||
|
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
|
||||||
|
inline helper (`require_workspace_member`). Caller is responsible for
|
||||||
|
short-circuiting on EE / SSO subjects before invoking — this function
|
||||||
|
runs the membership + active-status checks unconditionally.
|
||||||
|
"""
|
||||||
|
cached = cached_verdicts.get(tenant_id)
|
||||||
|
if cached is True:
|
||||||
|
return
|
||||||
|
if cached is False:
|
||||||
|
raise Forbidden("workspace_membership_revoked")
|
||||||
|
|
||||||
|
join = db.session.execute(
|
||||||
|
select(TenantAccountJoin.id).where(
|
||||||
|
TenantAccountJoin.account_id == account_id,
|
||||||
|
TenantAccountJoin.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if join is None:
|
||||||
|
record_layer0_verdict(token_hash, tenant_id, False)
|
||||||
|
raise Forbidden("workspace_membership_revoked")
|
||||||
|
|
||||||
|
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
|
||||||
|
if status != "active":
|
||||||
|
record_layer0_verdict(token_hash, tenant_id, False)
|
||||||
|
raise Forbidden("workspace_membership_revoked")
|
||||||
|
|
||||||
|
record_layer0_verdict(token_hash, tenant_id, True)
|
||||||
|
|
||||||
|
|
||||||
|
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
|
||||||
|
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
|
||||||
|
|
||||||
|
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
|
||||||
|
(no `tenant_account_joins` row by definition).
|
||||||
|
"""
|
||||||
|
if dify_config.ENTERPRISE_ENABLED:
|
||||||
|
return
|
||||||
|
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||||
|
return
|
||||||
|
check_workspace_membership(
|
||||||
|
account_id=ctx.account_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
token_hash=ctx.token_hash,
|
||||||
|
cached_verdicts=ctx.verified_tenants,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Decorator — route-level bearer gate
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
_authenticator: BearerAuthenticator | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
|
||||||
|
global _authenticator
|
||||||
|
_authenticator = authenticator
|
||||||
|
|
||||||
|
|
||||||
|
def get_authenticator() -> BearerAuthenticator:
|
||||||
|
if _authenticator is None:
|
||||||
|
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
|
||||||
|
return _authenticator
|
||||||
|
|
||||||
|
|
||||||
|
def extract_bearer(req) -> str | None:
|
||||||
|
"""Pull the bearer token out of an HTTP request's Authorization header.
|
||||||
|
|
||||||
|
Used by both attachment paths (the ``validate_bearer`` decorator and the
|
||||||
|
openapi ``Pipeline.guard``) so the parsing rule lives in one place. Pipeline
|
||||||
|
callers extract once at the boundary and pass the token through ``Context``
|
||||||
|
so steps stay independent of the request object.
|
||||||
|
"""
|
||||||
|
header = req.headers.get("Authorization", "")
|
||||||
|
scheme, _, value = header.partition(" ")
|
||||||
|
if scheme.lower() != "bearer" or not value:
|
||||||
|
return None
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
_DP = ParamSpec("_DP")
|
||||||
|
_DR = TypeVar("_DR")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
|
||||||
|
"""Opt-in: omitting it leaves the route unauthenticated.
|
||||||
|
|
||||||
|
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
|
||||||
|
``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
|
||||||
|
and are rejected here as the wrong auth scheme for this surface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
|
||||||
|
token = extract_bearer(request)
|
||||||
|
if token is None:
|
||||||
|
raise Unauthorized("missing bearer token")
|
||||||
|
|
||||||
|
if _authenticator is None:
|
||||||
|
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ctx = get_authenticator().authenticate(token)
|
||||||
|
except InvalidBearerError as e:
|
||||||
|
raise Unauthorized(str(e))
|
||||||
|
|
||||||
|
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
|
||||||
|
raise Forbidden("token subject type not accepted here")
|
||||||
|
|
||||||
|
# Try/finally pairing — the WSGI worker thread is reused
|
||||||
|
# across requests, so a leaked ContextVar would publish the
|
||||||
|
# previous caller's identity to the next request.
|
||||||
|
reset_token = set_auth_ctx(ctx)
|
||||||
|
try:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
reset_auth_ctx(reset_token)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
||||||
|
without the authenticator, so fail fast instead of approving silently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
if not dify_config.ENABLE_OAUTH_BEARER:
|
||||||
|
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def require_scope(scope: Scope) -> Callable:
|
||||||
|
"""Route-level scope gate — must run AFTER validate_bearer so that
|
||||||
|
the auth ContextVar is set. Raises ``Forbidden('insufficient_scope: <scope>')``
|
||||||
|
when the bearer lacks both the requested scope and ``Scope.FULL``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrap(fn: Callable) -> Callable:
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
ctx = try_get_auth_ctx()
|
||||||
|
if ctx is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
|
||||||
|
)
|
||||||
|
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
|
||||||
|
raise Forbidden(f"insufficient_scope: {scope}")
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Wiring — called once from the app factory
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||||
|
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
|
||||||
|
account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||||
|
external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||||
|
return TokenKindRegistry(
|
||||||
|
[
|
||||||
|
TokenKind(
|
||||||
|
prefix=account.prefix,
|
||||||
|
subject_type=account.subject_type,
|
||||||
|
scopes=account.scopes,
|
||||||
|
source="oauth_account",
|
||||||
|
resolver=oauth.for_account(),
|
||||||
|
),
|
||||||
|
TokenKind(
|
||||||
|
prefix=external.prefix,
|
||||||
|
subject_type=external.subject_type,
|
||||||
|
scopes=external.scopes,
|
||||||
|
source="oauth_external_sso",
|
||||||
|
resolver=oauth.for_external_sso(),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
|
||||||
|
registry = build_registry(session_factory, redis_client)
|
||||||
|
auth = BearerAuthenticator(registry)
|
||||||
|
bind_authenticator(auth)
|
||||||
|
return auth
|
||||||
147
api/libs/rate_limit.py
Normal file
147
api/libs/rate_limit.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
||||||
|
window Redis ZSET). Apply after auth decorators so account/email/token-id
|
||||||
|
scopes can read the openapi auth ContextVar (see
|
||||||
|
:func:`libs.oauth_bearer.try_get_auth_ctx`). Use :func:`enforce` when the
|
||||||
|
bucket key is computed in-handler. RFC-8628 ``slow_down`` is inline — its
|
||||||
|
response shape isn't generic 429.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import timedelta
|
||||||
|
from enum import StrEnum
|
||||||
|
from functools import wraps
|
||||||
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from flask import jsonify, make_response, request, session
|
||||||
|
from werkzeug.exceptions import TooManyRequests
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from libs.helper import RateLimiter, extract_remote_ip
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitScope(StrEnum):
|
||||||
|
IP = "ip"
|
||||||
|
SESSION = "session"
|
||||||
|
ACCOUNT = "account"
|
||||||
|
SUBJECT_EMAIL = "subject_email"
|
||||||
|
TOKEN_ID = "token_id"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class RateLimit:
|
||||||
|
limit: int
|
||||||
|
window: timedelta
|
||||||
|
scopes: tuple[RateLimitScope, ...]
|
||||||
|
|
||||||
|
|
||||||
|
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||||
|
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||||
|
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||||
|
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
|
||||||
|
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
|
||||||
|
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
|
||||||
|
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||||
|
LIMIT_BEARER_PER_TOKEN = RateLimit(
|
||||||
|
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
|
||||||
|
window=timedelta(minutes=1),
|
||||||
|
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _one_key(scope: RateLimitScope) -> str:
|
||||||
|
match scope:
|
||||||
|
case RateLimitScope.IP:
|
||||||
|
return f"ip:{extract_remote_ip(request) or 'unknown'}"
|
||||||
|
case RateLimitScope.SESSION:
|
||||||
|
return f"session:{session.get('_id', 'anon')}"
|
||||||
|
case RateLimitScope.ACCOUNT:
|
||||||
|
from libs.oauth_bearer import try_get_auth_ctx
|
||||||
|
|
||||||
|
ctx = try_get_auth_ctx()
|
||||||
|
if ctx and ctx.account_id:
|
||||||
|
return f"account:{ctx.account_id}"
|
||||||
|
return "account:anon"
|
||||||
|
case RateLimitScope.SUBJECT_EMAIL:
|
||||||
|
from libs.oauth_bearer import try_get_auth_ctx
|
||||||
|
|
||||||
|
ctx = try_get_auth_ctx()
|
||||||
|
if ctx and ctx.subject_email:
|
||||||
|
return f"subject:{ctx.subject_email}"
|
||||||
|
return "subject:anon"
|
||||||
|
case RateLimitScope.TOKEN_ID:
|
||||||
|
from libs.oauth_bearer import try_get_auth_ctx
|
||||||
|
|
||||||
|
ctx = try_get_auth_ctx()
|
||||||
|
if ctx and ctx.token_id:
|
||||||
|
return f"token:{ctx.token_id}"
|
||||||
|
return "token:anon"
|
||||||
|
|
||||||
|
|
||||||
|
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||||
|
return "|".join(_one_key(s) for s in scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||||
|
return "rl:" + "+".join(s.value for s in scopes)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_limiter(spec: RateLimit) -> RateLimiter:
|
||||||
|
return RateLimiter(
|
||||||
|
prefix=_limiter_prefix(spec.scopes),
|
||||||
|
max_attempts=spec.limit,
|
||||||
|
time_window=int(spec.window.total_seconds()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_P = ParamSpec("_P")
|
||||||
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||||
|
"""Apply after auth decorators that the scopes read from."""
|
||||||
|
limiter = _build_limiter(spec)
|
||||||
|
|
||||||
|
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||||
|
key = _composite_key(spec.scopes)
|
||||||
|
if limiter.is_rate_limited(key):
|
||||||
|
raise TooManyRequests("rate_limited")
|
||||||
|
limiter.increment_rate_limit(key)
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
return wrap
|
||||||
|
|
||||||
|
|
||||||
|
def enforce(spec: RateLimit, *, key: str) -> None:
|
||||||
|
"""Imperative form — caller composes the bucket key to match scope
|
||||||
|
semantics (the key is opaque here).
|
||||||
|
"""
|
||||||
|
limiter = _build_limiter(spec)
|
||||||
|
if limiter.is_rate_limited(key):
|
||||||
|
raise TooManyRequests("rate_limited")
|
||||||
|
limiter.increment_rate_limit(key)
|
||||||
|
|
||||||
|
|
||||||
|
def enforce_bearer_rate_limit(token_hash: str) -> None:
|
||||||
|
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
|
||||||
|
|
||||||
|
Bucket key = ``token:<sha256_hex>`` so the same token shares one
|
||||||
|
bucket across api replicas (Redis-backed sliding window).
|
||||||
|
"""
|
||||||
|
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
|
||||||
|
key = f"token:{token_hash}"
|
||||||
|
if limiter.is_rate_limited(key):
|
||||||
|
retry_after = limiter.seconds_until_available(key)
|
||||||
|
response = make_response(
|
||||||
|
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
|
||||||
|
429,
|
||||||
|
)
|
||||||
|
response.headers["Retry-After"] = str(retry_after)
|
||||||
|
raise TooManyRequests(response=response)
|
||||||
|
limiter.increment_rate_limit(key)
|
||||||
@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
|||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
def extract_access_token(request: Request) -> str | None:
|
def extract_console_cookie_token(request: Request) -> str | None:
|
||||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
"""Cookie-only console session token. Used by /openapi/v1/oauth/device/*
|
||||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
approval routes, which must not fall through to the Authorization header
|
||||||
|
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
|
||||||
|
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||||
|
|
||||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
|
||||||
|
def extract_access_token(request: Request) -> str | None:
|
||||||
|
return extract_console_cookie_token(request) or _try_extract_from_header(request)
|
||||||
|
|
||||||
|
|
||||||
def extract_webapp_access_token(request: Request) -> str | None:
|
def extract_webapp_access_token(request: Request) -> str | None:
|
||||||
|
|||||||
@ -0,0 +1,128 @@
|
|||||||
|
"""add oauth_access_tokens table
|
||||||
|
|
||||||
|
Revision ID: d4a5e1f3c9b7
|
||||||
|
Revises: f8b6b7e9c421
|
||||||
|
Create Date: 2026-05-22 17:00:00.000000
|
||||||
|
|
||||||
|
Table stores user-level OAuth bearer tokens minted via the device-flow grant
|
||||||
|
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
|
||||||
|
table not added in this migration.
|
||||||
|
|
||||||
|
Cross-dialect notes:
|
||||||
|
- UUID columns use ``models.types.StringUUID`` (UUID on PG, CHAR(36) on
|
||||||
|
MySQL). The application generates ids via ``libs.uuid_utils.uuidv7``;
|
||||||
|
on PG we additionally set a ``server_default`` so direct SQL inserts
|
||||||
|
remain valid.
|
||||||
|
- Indexed text columns are bounded ``VARCHAR(255)`` because MySQL cannot
|
||||||
|
index ``TEXT`` without an explicit prefix length.
|
||||||
|
- ``postgresql_where=`` is silently dropped by SQLAlchemy on MySQL, so the
|
||||||
|
partial-index filters degrade to plain indexes — semantically a
|
||||||
|
superset, still correct for lookup. The composite unique index on
|
||||||
|
``(subject_email, subject_issuer, client_id, device_label)`` enforces
|
||||||
|
uniqueness across both dialects (NULLs are distinct in both, matching
|
||||||
|
the rotate-in-place contract documented on ``OAuthAccessToken``).
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
import models
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "d4a5e1f3c9b7"
|
||||||
|
down_revision = "f8b6b7e9c421"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pg() -> bool:
|
||||||
|
return op.get_bind().dialect.name == "postgresql"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
id_kwargs: dict = {"nullable": False, "primary_key": True}
|
||||||
|
if _is_pg():
|
||||||
|
# Match the convention established by 2026_05_19_1000 (uuidv7()).
|
||||||
|
id_kwargs["server_default"] = sa.text("uuidv7()")
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"oauth_access_tokens",
|
||||||
|
sa.Column("id", models.types.StringUUID(), **id_kwargs),
|
||||||
|
sa.Column("subject_email", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("subject_issuer", sa.String(length=255), nullable=True),
|
||||||
|
sa.Column("account_id", models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||||
|
sa.Column("device_label", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("prefix", sa.String(length=8), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
server_default=sa.func.current_timestamp(),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["account_id"],
|
||||||
|
["accounts.id"],
|
||||||
|
name="fk_oauth_access_tokens_account_id",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Partial-index WHERE clauses are PG-only (SQLAlchemy drops the kwarg
|
||||||
|
# on MySQL → plain index, which is still correct for lookup).
|
||||||
|
op.create_index(
|
||||||
|
"idx_oauth_subject_email",
|
||||||
|
"oauth_access_tokens",
|
||||||
|
["subject_email"],
|
||||||
|
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"idx_oauth_account",
|
||||||
|
"oauth_access_tokens",
|
||||||
|
["account_id"],
|
||||||
|
postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"idx_oauth_client",
|
||||||
|
"oauth_access_tokens",
|
||||||
|
["subject_email", "client_id"],
|
||||||
|
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"idx_oauth_token_hash",
|
||||||
|
"oauth_access_tokens",
|
||||||
|
["token_hash"],
|
||||||
|
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||||
|
)
|
||||||
|
# Rotate-in-place keyed on (subject, client, device). The app always
|
||||||
|
# writes a non-NULL subject_issuer (account flow uses a sentinel,
|
||||||
|
# external-SSO uses the verified IdP issuer); without that guarantee
|
||||||
|
# the composite key would never collide because both PG and MySQL
|
||||||
|
# treat NULLs as distinct in unique indices.
|
||||||
|
#
|
||||||
|
# ``mysql_length`` truncates each text column to 191 chars in the index
|
||||||
|
# — utf8mb4 makes the per-row index entry (191+191+64+191)*4 = 2548
|
||||||
|
# bytes, comfortably under InnoDB's 3072-byte index limit. Collisions
|
||||||
|
# on the 191-char prefix are vanishingly unlikely for real emails /
|
||||||
|
# OIDC issuers / device labels, and the app re-checks the full-row
|
||||||
|
# invariant before issuing a rotation.
|
||||||
|
op.create_index(
|
||||||
|
"uq_oauth_active_per_device",
|
||||||
|
"oauth_access_tokens",
|
||||||
|
["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||||
|
mysql_length={"subject_email": 191, "subject_issuer": 191, "device_label": 191},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens")
|
||||||
|
op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens")
|
||||||
|
op.drop_index("idx_oauth_client", table_name="oauth_access_tokens")
|
||||||
|
op.drop_index("idx_oauth_account", table_name="oauth_access_tokens")
|
||||||
|
op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens")
|
||||||
|
op.drop_table("oauth_access_tokens")
|
||||||
@ -86,7 +86,7 @@ from .model import (
|
|||||||
TrialApp,
|
TrialApp,
|
||||||
UploadFile,
|
UploadFile,
|
||||||
)
|
)
|
||||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
|
||||||
from .provider import (
|
from .provider import (
|
||||||
LoadBalancingModelConfig,
|
LoadBalancingModelConfig,
|
||||||
Provider,
|
Provider,
|
||||||
@ -199,6 +199,7 @@ __all__ = [
|
|||||||
"MessageChain",
|
"MessageChain",
|
||||||
"MessageFeedback",
|
"MessageFeedback",
|
||||||
"MessageFile",
|
"MessageFile",
|
||||||
|
"OAuthAccessToken",
|
||||||
"OperationLog",
|
"OperationLog",
|
||||||
"PinnedConversation",
|
"PinnedConversation",
|
||||||
"Provider",
|
"Provider",
|
||||||
|
|||||||
@ -185,6 +185,7 @@ class InvokeFrom(StrEnum):
|
|||||||
DEBUGGER = "debugger"
|
DEBUGGER = "debugger"
|
||||||
PUBLISHED_PIPELINE = "published"
|
PUBLISHED_PIPELINE = "published"
|
||||||
VALIDATION = "validation"
|
VALIDATION = "validation"
|
||||||
|
OPENAPI = "openapi"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "InvokeFrom":
|
def value_of(cls, value: str) -> "InvokeFrom":
|
||||||
@ -197,6 +198,7 @@ class InvokeFrom(StrEnum):
|
|||||||
InvokeFrom.EXPLORE: "explore_app",
|
InvokeFrom.EXPLORE: "explore_app",
|
||||||
InvokeFrom.TRIGGER: "trigger",
|
InvokeFrom.TRIGGER: "trigger",
|
||||||
InvokeFrom.SERVICE_API: "api",
|
InvokeFrom.SERVICE_API: "api",
|
||||||
|
InvokeFrom.OPENAPI: "openapi",
|
||||||
}
|
}
|
||||||
return source_mapping.get(self, "dev")
|
return source_mapping.get(self, "dev")
|
||||||
|
|
||||||
|
|||||||
@ -492,8 +492,8 @@ class App(Base):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
# get agent mode tools
|
# get agent mode tools
|
||||||
app_model_config = self.app_model_config
|
app_model_config = self.app_model_config
|
||||||
|
|||||||
@ -84,3 +84,39 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
|||||||
onupdate=func.current_timestamp(),
|
onupdate=func.current_timestamp(),
|
||||||
init=False,
|
init=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccessToken(TypeBase):
|
||||||
|
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
|
||||||
|
subject_issuer = "dify:account" sentinel); account_id NULL +
|
||||||
|
subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
|
||||||
|
subject_issuer is non-NULL for all rows the app writes — Postgres
|
||||||
|
treats NULLs as distinct in unique indices, so the partial unique
|
||||||
|
index on (subject_email, subject_issuer, client_id, device_label)
|
||||||
|
WHERE revoked_at IS NULL would otherwise fail to rotate in place.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "oauth_access_tokens"
|
||||||
|
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
|
||||||
|
)
|
||||||
|
# Indexed text columns are bounded VARCHARs so the schema is portable
|
||||||
|
# across PostgreSQL and MySQL (MySQL cannot index TEXT without a prefix
|
||||||
|
# length). 255 chars accommodates RFC-compliant emails and typical
|
||||||
|
# OIDC issuer URLs / device labels.
|
||||||
|
subject_email: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
|
client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
|
||||||
|
device_label: Mapped[str] = mapped_column(sa.String(255), nullable=False)
|
||||||
|
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
|
||||||
|
subject_issuer: Mapped[str | None] = mapped_column(sa.String(255), nullable=True, default=None)
|
||||||
|
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||||
|
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
|
||||||
|
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||||
|
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
|
||||||
|
)
|
||||||
|
|||||||
@ -1209,6 +1209,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
|||||||
SERVICE_API = "service-api"
|
SERVICE_API = "service-api"
|
||||||
WEB_APP = "web-app"
|
WEB_APP = "web-app"
|
||||||
INSTALLED_APP = "installed-app"
|
INSTALLED_APP = "installed-app"
|
||||||
|
OPENAPI = "openapi"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||||
|
|||||||
656
api/openapi/markdown/openapi-swagger.md
Normal file
656
api/openapi/markdown/openapi-swagger.md
Normal file
@ -0,0 +1,656 @@
|
|||||||
|
# OpenAPI
|
||||||
|
User-scoped programmatic API (bearer auth)
|
||||||
|
|
||||||
|
## Version: 1.0
|
||||||
|
|
||||||
|
### Security
|
||||||
|
**Bearer**
|
||||||
|
|
||||||
|
| apiKey | *API Key* |
|
||||||
|
| ------ | --------- |
|
||||||
|
| Description | Type: Bearer {your-api-key} |
|
||||||
|
| In | header |
|
||||||
|
| Name | Authorization |
|
||||||
|
|
||||||
|
---
|
||||||
|
## openapi
|
||||||
|
User-scoped operations
|
||||||
|
|
||||||
|
### /_health
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Success |
|
||||||
|
|
||||||
|
### /_version
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Server version | [ServerVersionResponse](#serverversionresponse) |
|
||||||
|
|
||||||
|
### /account
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Account info | [AccountResponse](#accountresponse) |
|
||||||
|
|
||||||
|
### /account/sessions
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Session list | [SessionListResponse](#sessionlistresponse) |
|
||||||
|
|
||||||
|
### /account/sessions/self
|
||||||
|
|
||||||
|
#### DELETE
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Session revoked | [RevokeResponse](#revokeresponse) |
|
||||||
|
|
||||||
|
### /account/sessions/{session_id}
|
||||||
|
|
||||||
|
#### DELETE
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| session_id | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Session revoked | [RevokeResponse](#revokeresponse) |
|
||||||
|
|
||||||
|
### /apps
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| limit | query | | No | integer |
|
||||||
|
| mode | query | | No | string |
|
||||||
|
| name | query | | No | string |
|
||||||
|
| page | query | | No | integer |
|
||||||
|
| tag | query | | No | string |
|
||||||
|
| workspace_id | query | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | App list | [AppListResponse](#applistresponse) |
|
||||||
|
|
||||||
|
### /apps/{app_id}/describe
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| fields | query | | No | [ string ] |
|
||||||
|
| workspace_id | query | | No | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | App description | [AppDescribeResponse](#appdescriberesponse) |
|
||||||
|
|
||||||
|
### /apps/{app_id}/files/upload
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Description
|
||||||
|
|
||||||
|
Upload a file to use as an input variable when running the app
|
||||||
|
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 201 | File uploaded successfully | [FileResponse](#fileresponse) |
|
||||||
|
| 400 | Bad request — no file or filename missing | |
|
||||||
|
| 401 | Unauthorized — invalid or expired bearer token | |
|
||||||
|
| 413 | File too large | |
|
||||||
|
| 415 | Unsupported file type or blocked extension | |
|
||||||
|
|
||||||
|
### /apps/{app_id}/form/human_input/{form_token}
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| form_token | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Form definition |
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| form_token | path | | Yes | string |
|
||||||
|
| payload | body | | Yes | [HumanInputFormSubmitPayload](#humaninputformsubmitpayload) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Form submitted |
|
||||||
|
|
||||||
|
### /apps/{app_id}/run
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| payload | body | | Yes | [AppRunRequest](#apprunrequest) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Run result (SSE stream) |
|
||||||
|
|
||||||
|
### /apps/{app_id}/tasks/{task_id}/events
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| task_id | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | SSE event stream |
|
||||||
|
|
||||||
|
### /apps/{app_id}/tasks/{task_id}/stop
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| app_id | path | | Yes | string |
|
||||||
|
| task_id | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Task stopped |
|
||||||
|
|
||||||
|
### /oauth/device/approve
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| payload | body | | Yes | [DeviceMutateRequest](#devicemutaterequest) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Approved | [DeviceMutateResponse](#devicemutateresponse) |
|
||||||
|
|
||||||
|
### /oauth/device/code
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| payload | body | | Yes | [DeviceCodeRequest](#devicecoderequest) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Device code created | [DeviceCodeResponse](#devicecoderesponse) |
|
||||||
|
|
||||||
|
### /oauth/device/deny
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| payload | body | | Yes | [DeviceMutateRequest](#devicemutaterequest) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Denied | [DeviceMutateResponse](#devicemutateresponse) |
|
||||||
|
|
||||||
|
### /oauth/device/lookup
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| user_code | query | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Device lookup result | [DeviceLookupResponse](#devicelookupresponse) |
|
||||||
|
|
||||||
|
### /oauth/device/token
|
||||||
|
|
||||||
|
#### POST
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| payload | body | | Yes | [DevicePollRequest](#devicepollrequest) |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description |
|
||||||
|
| ---- | ----------- |
|
||||||
|
| 200 | Success |
|
||||||
|
|
||||||
|
### /permitted-external-apps
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Permitted external apps list | [PermittedExternalAppsListResponse](#permittedexternalappslistresponse) |
|
||||||
|
|
||||||
|
### /workspaces
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Workspace list | [WorkspaceListResponse](#workspacelistresponse) |
|
||||||
|
|
||||||
|
### /workspaces/{workspace_id}
|
||||||
|
|
||||||
|
#### GET
|
||||||
|
##### Parameters
|
||||||
|
|
||||||
|
| Name | Located in | Description | Required | Schema |
|
||||||
|
| ---- | ---------- | ----------- | -------- | ------ |
|
||||||
|
| workspace_id | path | | Yes | string |
|
||||||
|
|
||||||
|
##### Responses
|
||||||
|
|
||||||
|
| Code | Description | Schema |
|
||||||
|
| ---- | ----------- | ------ |
|
||||||
|
| 200 | Workspace detail | [WorkspaceDetailResponse](#workspacedetailresponse) |
|
||||||
|
|
||||||
|
---
|
||||||
|
### Models
|
||||||
|
|
||||||
|
#### AccountPayload
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| email | string | | Yes |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
|
||||||
|
#### AccountResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| account | [AccountPayload](#accountpayload) | | No |
|
||||||
|
| default_workspace_id | string | | No |
|
||||||
|
| subject_email | string | | No |
|
||||||
|
| subject_issuer | string | | No |
|
||||||
|
| subject_type | string | | Yes |
|
||||||
|
| workspaces | [ [WorkspacePayload](#workspacepayload) ] | | No |
|
||||||
|
|
||||||
|
#### AppDescribeInfo
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| author | string | | No |
|
||||||
|
| description | string | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| is_agent | boolean | | No |
|
||||||
|
| mode | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| service_api_enabled | boolean | | Yes |
|
||||||
|
| tags | [ [TagItem](#tagitem) ] | | No |
|
||||||
|
| updated_at | string | | No |
|
||||||
|
|
||||||
|
#### AppDescribeQuery
|
||||||
|
|
||||||
|
`?fields=` allow-list for GET /apps/<id>/describe.
|
||||||
|
|
||||||
|
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| fields | [ string ] | | No |
|
||||||
|
| workspace_id | string | | No |
|
||||||
|
|
||||||
|
#### AppDescribeResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| info | [AppDescribeInfo](#appdescribeinfo) | | No |
|
||||||
|
| input_schema | object | | No |
|
||||||
|
| parameters | object | | No |
|
||||||
|
|
||||||
|
#### AppInfoResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| author | string | | No |
|
||||||
|
| description | string | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| mode | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| tags | [ [TagItem](#tagitem) ] | | No |
|
||||||
|
|
||||||
|
#### AppListQuery
|
||||||
|
|
||||||
|
mode is a closed enum.
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| limit | integer | | No |
|
||||||
|
| mode | [AppMode](#appmode) | | No |
|
||||||
|
| name | string | | No |
|
||||||
|
| page | integer | | No |
|
||||||
|
| tag | string | | No |
|
||||||
|
| workspace_id | string | | Yes |
|
||||||
|
|
||||||
|
#### AppListResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| data | [ [AppListRow](#applistrow) ] | | Yes |
|
||||||
|
| has_more | boolean | | Yes |
|
||||||
|
| limit | integer | | Yes |
|
||||||
|
| page | integer | | Yes |
|
||||||
|
| total | integer | | Yes |
|
||||||
|
|
||||||
|
#### AppListRow
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| created_by_name | string | | No |
|
||||||
|
| description | string | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| mode | [AppMode](#appmode) | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| tags | [ [TagItem](#tagitem) ] | | No |
|
||||||
|
| updated_at | string | | No |
|
||||||
|
| workspace_id | string | | No |
|
||||||
|
| workspace_name | string | | No |
|
||||||
|
|
||||||
|
#### AppMode
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| AppMode | string | | |
|
||||||
|
|
||||||
|
#### AppRunRequest
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| auto_generate_name | boolean | | No |
|
||||||
|
| conversation_id | string | | No |
|
||||||
|
| files | [ object ] | | No |
|
||||||
|
| inputs | object | | Yes |
|
||||||
|
| query | string | | No |
|
||||||
|
| workflow_id | string | | No |
|
||||||
|
| workspace_id | string | | No |
|
||||||
|
|
||||||
|
#### DeviceCodeRequest
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| client_id | string | | Yes |
|
||||||
|
| device_label | string | | Yes |
|
||||||
|
|
||||||
|
#### DeviceCodeResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| device_code | string | | Yes |
|
||||||
|
| expires_in | integer | | Yes |
|
||||||
|
| interval | integer | | Yes |
|
||||||
|
| user_code | string | | Yes |
|
||||||
|
| verification_uri | string | | Yes |
|
||||||
|
|
||||||
|
#### DeviceLookupQuery
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| user_code | string | | Yes |
|
||||||
|
|
||||||
|
#### DeviceLookupResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| client_id | string | | No |
|
||||||
|
| expires_in_remaining | integer | | No |
|
||||||
|
| valid | boolean | | Yes |
|
||||||
|
|
||||||
|
#### DeviceMutateRequest
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| user_code | string | | Yes |
|
||||||
|
|
||||||
|
#### DeviceMutateResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| status | string | | Yes |
|
||||||
|
|
||||||
|
#### DevicePollRequest
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| client_id | string | | Yes |
|
||||||
|
| device_code | string | | Yes |
|
||||||
|
|
||||||
|
#### FileResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| conversation_id | string | | No |
|
||||||
|
| created_at | integer | | No |
|
||||||
|
| created_by | string | | No |
|
||||||
|
| extension | string | | No |
|
||||||
|
| file_key | string | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| mime_type | string | | No |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| original_url | string | | No |
|
||||||
|
| preview_url | string | | No |
|
||||||
|
| size | integer | | Yes |
|
||||||
|
| source_url | string | | No |
|
||||||
|
| tenant_id | string | | No |
|
||||||
|
| user_id | string | | No |
|
||||||
|
|
||||||
|
#### HumanInputFormSubmitPayload
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| action | string | | Yes |
|
||||||
|
| inputs | object | | Yes |
|
||||||
|
|
||||||
|
#### JsonValue
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| JsonValue | | | |
|
||||||
|
|
||||||
|
#### MessageMetadata
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| retriever_resources | [ object ] | | No |
|
||||||
|
| usage | [UsageInfo](#usageinfo) | | No |
|
||||||
|
|
||||||
|
#### PermittedExternalAppsListQuery
|
||||||
|
|
||||||
|
Strict (extra='forbid').
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| limit | integer | | No |
|
||||||
|
| mode | [AppMode](#appmode) | | No |
|
||||||
|
| name | string | | No |
|
||||||
|
| page | integer | | No |
|
||||||
|
|
||||||
|
#### PermittedExternalAppsListResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| data | [ [AppListRow](#applistrow) ] | | Yes |
|
||||||
|
| has_more | boolean | | Yes |
|
||||||
|
| limit | integer | | Yes |
|
||||||
|
| page | integer | | Yes |
|
||||||
|
| total | integer | | Yes |
|
||||||
|
|
||||||
|
#### RevokeResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| status | string | | Yes |
|
||||||
|
|
||||||
|
#### ServerVersionResponse
|
||||||
|
|
||||||
|
Meta endpoint payload for `GET /openapi/v1/_version` — no auth required.
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| edition | string | *Enum:* `"CLOUD"`, `"SELF_HOSTED"` | Yes |
|
||||||
|
| version | string | | Yes |
|
||||||
|
|
||||||
|
#### SessionListResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| data | [ [SessionRow](#sessionrow) ] | | Yes |
|
||||||
|
| has_more | boolean | | Yes |
|
||||||
|
| limit | integer | | Yes |
|
||||||
|
| page | integer | | Yes |
|
||||||
|
| total | integer | | Yes |
|
||||||
|
|
||||||
|
#### SessionRow
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| client_id | string | | Yes |
|
||||||
|
| created_at | string | | No |
|
||||||
|
| device_label | string | | Yes |
|
||||||
|
| expires_at | string | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| last_used_at | string | | No |
|
||||||
|
| prefix | string | | Yes |
|
||||||
|
|
||||||
|
#### TagItem
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| name | string | | Yes |
|
||||||
|
|
||||||
|
#### UsageInfo
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| completion_tokens | integer | | No |
|
||||||
|
| prompt_tokens | integer | | No |
|
||||||
|
| total_tokens | integer | | No |
|
||||||
|
|
||||||
|
#### WorkflowRunData
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| created_at | integer | | No |
|
||||||
|
| elapsed_time | number | | No |
|
||||||
|
| error | string | | No |
|
||||||
|
| finished_at | integer | | No |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| outputs | object | | No |
|
||||||
|
| status | string | | Yes |
|
||||||
|
| total_steps | integer | | No |
|
||||||
|
| total_tokens | integer | | No |
|
||||||
|
| workflow_id | string | | Yes |
|
||||||
|
|
||||||
|
#### WorkspaceDetailResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| created_at | string | | No |
|
||||||
|
| current | boolean | | Yes |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| role | string | | Yes |
|
||||||
|
| status | string | | Yes |
|
||||||
|
|
||||||
|
#### WorkspaceListResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| workspaces | [ [WorkspaceSummaryResponse](#workspacesummaryresponse) ] | | Yes |
|
||||||
|
|
||||||
|
#### WorkspacePayload
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| role | string | | Yes |
|
||||||
|
|
||||||
|
#### WorkspaceSummaryResponse
|
||||||
|
|
||||||
|
| Name | Type | Description | Required |
|
||||||
|
| ---- | ---- | ----------- | -------- |
|
||||||
|
| current | boolean | | Yes |
|
||||||
|
| id | string | | Yes |
|
||||||
|
| name | string | | Yes |
|
||||||
|
| role | string | | Yes |
|
||||||
|
| status | string | | Yes |
|
||||||
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
|
||||||
|
(token_id stays for audits) so rows accumulate across re-logins, and
|
||||||
|
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||||
|
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import click
|
||||||
|
from sqlalchemy import delete, or_, select
|
||||||
|
|
||||||
|
import app
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.oauth import OAuthAccessToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DELETE_BATCH_SIZE = 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.celery.task(queue="retention")
|
||||||
|
def clean_oauth_access_tokens_task():
|
||||||
|
click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
|
||||||
|
retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
|
||||||
|
cutoff = datetime.now(UTC) - timedelta(days=retention_days)
|
||||||
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
|
candidates = or_(
|
||||||
|
OAuthAccessToken.revoked_at < cutoff,
|
||||||
|
# Zombies: expired but never re-presented, so middleware never flipped them.
|
||||||
|
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
|
||||||
|
)
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
while True:
|
||||||
|
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
|
||||||
|
if not ids:
|
||||||
|
break
|
||||||
|
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
|
||||||
|
db.session.commit()
|
||||||
|
total += len(ids)
|
||||||
|
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
click.echo(
|
||||||
|
click.style(
|
||||||
|
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
@ -8,7 +8,8 @@ from hashlib import sha256
|
|||||||
from typing import Any, TypedDict, cast
|
from typing import Any, TypedDict, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
from pydantic import BaseModel, TypeAdapter, ValidationError
|
||||||
from sqlalchemy import delete, func, select, update
|
from sqlalchemy import Row, delete, func, select, update
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
|
|
||||||
@ -163,6 +164,41 @@ class AccountService:
|
|||||||
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||||
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_account_by_email(session: Session | scoped_session, email: str) -> Account | None:
|
||||||
|
"""Plain ``Account`` getter keyed by email. Case-sensitive — use
|
||||||
|
:meth:`has_active_account_with_email` for the case-insensitive
|
||||||
|
existence check that backs the SSO collision rule.
|
||||||
|
"""
|
||||||
|
return session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def has_active_account_with_email(session: Session | scoped_session, email: str) -> bool:
|
||||||
|
if not email:
|
||||||
|
return False
|
||||||
|
normalized = email.strip().lower()
|
||||||
|
if not normalized:
|
||||||
|
return False
|
||||||
|
row = session.execute(
|
||||||
|
select(Account.id).where(
|
||||||
|
func.lower(Account.email) == normalized,
|
||||||
|
Account.status == AccountStatus.ACTIVE,
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_account_by_id(session: Session | scoped_session, account_id: str) -> Account | None:
|
||||||
|
"""Plain ``Account`` getter — no banned check, no tenant rotation,
|
||||||
|
no ``last_active_at`` write. Use this from read-only identity
|
||||||
|
endpoints (``/openapi/v1/account``) where ``load_user``'s
|
||||||
|
side-effects (current-tenant assignment, commit) are unwanted.
|
||||||
|
|
||||||
|
``session`` is injected by the caller so this service stays free
|
||||||
|
of the Flask-scoped ``db.session`` import.
|
||||||
|
"""
|
||||||
|
return session.get(Account, account_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_user(user_id: str) -> None | Account:
|
def load_user(user_id: str) -> None | Account:
|
||||||
account = db.session.get(Account, user_id)
|
account = db.session.get(Account, user_id)
|
||||||
@ -1182,6 +1218,127 @@ class TenantService:
|
|||||||
).all()
|
).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_account_memberships(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
account_id: str,
|
||||||
|
) -> list[Row[tuple[TenantAccountJoin, Tenant]]]:
|
||||||
|
"""Return ``(TenantAccountJoin, Tenant)`` rows for every workspace
|
||||||
|
the account belongs to. Unlike :meth:`get_join_tenants` this keeps
|
||||||
|
the join row so callers can read ``role``/``current`` alongside the
|
||||||
|
tenant — used by ``/openapi/v1/account`` to render workspace
|
||||||
|
membership + pick the default workspace.
|
||||||
|
|
||||||
|
``session`` is injected by the caller so this service stays free
|
||||||
|
of the Flask-scoped ``db.session`` import.
|
||||||
|
|
||||||
|
No tenant-status filter: parity with the legacy controller query
|
||||||
|
(the openapi identity endpoint listed all joined tenants).
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
session.query(TenantAccountJoin, Tenant)
|
||||||
|
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
|
||||||
|
.filter(TenantAccountJoin.account_id == account_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_workspaces_for_account(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
account_id: str,
|
||||||
|
) -> list[Row[tuple[Tenant, TenantAccountJoin]]]:
|
||||||
|
"""``(Tenant, TenantAccountJoin)`` rows for every workspace the
|
||||||
|
account belongs to, ordered by ``Tenant.created_at`` ASC — the
|
||||||
|
canonical ordering for ``/openapi/v1/workspaces``.
|
||||||
|
|
||||||
|
Distinct from :meth:`get_account_memberships`: tuple order is
|
||||||
|
flipped (tenant first) and rows are sorted, so the workspace
|
||||||
|
listing is stable across requests.
|
||||||
|
"""
|
||||||
|
return list(
|
||||||
|
session.execute(
|
||||||
|
select(Tenant, TenantAccountJoin)
|
||||||
|
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||||
|
.where(TenantAccountJoin.account_id == account_id)
|
||||||
|
.order_by(Tenant.created_at.asc())
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def account_belongs_to_tenant(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
account_id: uuid.UUID | str | None,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Existence check for ``TenantAccountJoin(account_id, tenant_id)``.
|
||||||
|
Backs the CE-deployment membership fallback in
|
||||||
|
``controllers.openapi.auth.strategies.MembershipStrategy``.
|
||||||
|
|
||||||
|
``None``/empty ``account_id`` short-circuits to ``False`` so SSO
|
||||||
|
bearers (no account) and missing identity collapse cleanly.
|
||||||
|
"""
|
||||||
|
if not account_id:
|
||||||
|
return False
|
||||||
|
row = session.execute(
|
||||||
|
select(TenantAccountJoin.id).where(
|
||||||
|
TenantAccountJoin.tenant_id == tenant_id,
|
||||||
|
TenantAccountJoin.account_id == account_id,
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None:
|
||||||
|
"""Plain ``session.get(Tenant, tenant_id)`` — no status filter.
|
||||||
|
Callers map ``status == ARCHIVE`` to their own error code (the
|
||||||
|
openapi auth pipeline raises 403 ``workspace unavailable``).
|
||||||
|
"""
|
||||||
|
return session.get(Tenant, tenant_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenants_by_ids(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
tenant_ids: list[str],
|
||||||
|
) -> list[Tenant]:
|
||||||
|
"""Bulk ``Tenant`` fetch by primary-key list. Order is unspecified
|
||||||
|
— callers index by ``tenant.id`` (e.g. for cross-tenant denorm
|
||||||
|
in ``/openapi/v1/permitted-external-apps``).
|
||||||
|
|
||||||
|
Empty input short-circuits to ``[]`` to avoid emitting an
|
||||||
|
``IN ()`` SQL fragment.
|
||||||
|
"""
|
||||||
|
if not tenant_ids:
|
||||||
|
return []
|
||||||
|
return list(session.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_name(session: Session | scoped_session, tenant_id: str) -> str | None:
|
||||||
|
"""Single-column tenant name read. Used by openapi list endpoints
|
||||||
|
to denormalize ``workspace_name`` onto each row without dragging
|
||||||
|
the full ``Tenant`` ORM entity through.
|
||||||
|
"""
|
||||||
|
return session.execute(select(Tenant.name).where(Tenant.id == tenant_id)).scalar_one_or_none()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_workspace_for_account(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
account_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
) -> Row[tuple[Tenant, TenantAccountJoin]] | None:
|
||||||
|
"""Single ``(Tenant, TenantAccountJoin)`` row scoped to the
|
||||||
|
account's membership in ``workspace_id``. ``None`` on non-member
|
||||||
|
— the caller maps that to 404 (not 403) so workspace IDs don't
|
||||||
|
leak across tenants via response codes.
|
||||||
|
"""
|
||||||
|
return session.execute(
|
||||||
|
select(Tenant, TenantAccountJoin)
|
||||||
|
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||||
|
.where(
|
||||||
|
Tenant.id == workspace_id,
|
||||||
|
TenantAccountJoin.account_id == account_id,
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_current_tenant_by_account(account: Account):
|
def get_current_tenant_by_account(account: Account):
|
||||||
"""Get tenant by account and add the role"""
|
"""Get tenant by account and add the role"""
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Any, Literal, TypedDict, cast
|
from typing import Any, Literal, TypedDict, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask_sqlalchemy.pagination import Pagination
|
from flask_sqlalchemy.pagination import Pagination
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.model_template import default_app_templates
|
from constants.model_template import default_app_templates
|
||||||
@ -26,6 +28,7 @@ from models.tools import ApiToolProvider
|
|||||||
from services.billing_service import BillingService
|
from services.billing_service import BillingService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
|
||||||
from services.tag_service import TagService
|
from services.tag_service import TagService
|
||||||
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
|
||||||
|
|
||||||
@ -39,6 +42,8 @@ class AppListParams(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
tag_ids: list[str] | None = None
|
tag_ids: list[str] | None = None
|
||||||
is_created_by_me: bool | None = None
|
is_created_by_me: bool | None = None
|
||||||
|
status: str | None = None
|
||||||
|
openapi_visible: bool = False
|
||||||
|
|
||||||
|
|
||||||
class CreateAppParams(BaseModel):
|
class CreateAppParams(BaseModel):
|
||||||
@ -54,6 +59,51 @@ class CreateAppParams(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AppService:
|
class AppService:
|
||||||
|
@staticmethod
|
||||||
|
def get_app_by_id(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
app_id: str,
|
||||||
|
) -> App | None:
|
||||||
|
return session.get(App, app_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_visible_app_by_id(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
app_id: str,
|
||||||
|
) -> App | None:
|
||||||
|
app = session.get(App, app_id)
|
||||||
|
if not app or app.status != "normal" or not is_openapi_visible(app):
|
||||||
|
return None
|
||||||
|
return app
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_visible_apps_by_ids(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
app_ids: Sequence[str],
|
||||||
|
) -> list[App]:
|
||||||
|
if not app_ids:
|
||||||
|
return []
|
||||||
|
return list(session.execute(apply_openapi_gate(select(App).where(App.id.in_(list(app_ids))))).scalars().all())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_visible_apps_by_name(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[App]:
|
||||||
|
return list(
|
||||||
|
session.execute(
|
||||||
|
apply_openapi_gate(
|
||||||
|
select(App).where(
|
||||||
|
App.name == name,
|
||||||
|
App.tenant_id == tenant_id,
|
||||||
|
App.status == "normal",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalars()
|
||||||
|
)
|
||||||
|
|
||||||
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
|
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
|
||||||
"""
|
"""
|
||||||
Get app list with pagination
|
Get app list with pagination
|
||||||
@ -75,6 +125,14 @@ class AppService:
|
|||||||
elif params.mode == "agent-chat":
|
elif params.mode == "agent-chat":
|
||||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||||
|
|
||||||
|
if params.status:
|
||||||
|
filters.append(App.status == params.status)
|
||||||
|
# OpenAPI surface visibility gate. Pushed into the query so
|
||||||
|
# `pagination.total` reflects only apps the openapi caller can
|
||||||
|
# actually reach — post-filtering by enable_api after the page
|
||||||
|
# arrives would make `total` page-dependent.
|
||||||
|
if params.openapi_visible:
|
||||||
|
filters.append(App.enable_api.is_(True))
|
||||||
if params.is_created_by_me:
|
if params.is_created_by_me:
|
||||||
filters.append(App.created_by == user_id)
|
filters.append(App.created_by == user_id)
|
||||||
if params.name:
|
if params.name:
|
||||||
|
|||||||
@ -14,13 +14,13 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
|||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from graphon.model_runtime.entities.provider_entities import FormType
|
from graphon.model_runtime.entities.provider_entities import FormType
|
||||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||||
from models.provider_ids import DatasourceProviderID
|
from models.provider_ids import DatasourceProviderID
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
44
api/services/enterprise/app_permitted_service.py
Normal file
44
api/services/enterprise/app_permitted_service.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from werkzeug.exceptions import ServiceUnavailable
|
||||||
|
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.errors.enterprise import EnterpriseAPIError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class PermittedAppsPage:
|
||||||
|
app_ids: list[str]
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
def list_permitted_apps(
|
||||||
|
*,
|
||||||
|
page: int,
|
||||||
|
limit: int,
|
||||||
|
mode: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> PermittedAppsPage:
|
||||||
|
try:
|
||||||
|
body = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
|
||||||
|
page=page, limit=limit, mode=mode, name=name
|
||||||
|
)
|
||||||
|
except EnterpriseAPIError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"permitted_apps EE call failed: status=%s message=%s",
|
||||||
|
getattr(exc, "status_code", None),
|
||||||
|
str(exc),
|
||||||
|
)
|
||||||
|
raise ServiceUnavailable("permitted_apps_unavailable") from exc
|
||||||
|
|
||||||
|
return PermittedAppsPage(
|
||||||
|
app_ids=[row["appId"] for row in body.get("data", [])],
|
||||||
|
total=int(body.get("total", 0)),
|
||||||
|
has_more=bool(body.get("hasMore", False)),
|
||||||
|
)
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -24,10 +25,22 @@ VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
|||||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAccessMode(enum.StrEnum):
|
||||||
|
PUBLIC = "public"
|
||||||
|
PRIVATE = "private"
|
||||||
|
PRIVATE_ALL = "private_all"
|
||||||
|
SSO_VERIFIED = "sso_verified"
|
||||||
|
|
||||||
|
|
||||||
|
PERMISSION_CHECK_MODES: frozenset[WebAppAccessMode] = frozenset(
|
||||||
|
{WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebAppSettings(BaseModel):
|
class WebAppSettings(BaseModel):
|
||||||
access_mode: str = Field(
|
access_mode: str = Field(
|
||||||
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
|
description=f"Access mode for the web app. One of: {', '.join(m.value for m in WebAppAccessMode)}",
|
||||||
default="private",
|
default=WebAppAccessMode.PRIVATE.value,
|
||||||
alias="accessMode",
|
alias="accessMode",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -108,6 +121,15 @@ class EnterpriseService:
|
|||||||
def get_workspace_info(cls, tenant_id: str):
|
def get_workspace_info(cls, tenant_id: str):
|
||||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||||
|
return EnterpriseRequest.send_request(
|
||||||
|
"POST",
|
||||||
|
"/device-flow/sso-initiate",
|
||||||
|
json={"signed_state": signed_state},
|
||||||
|
raise_for_status=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||||
"""
|
"""
|
||||||
@ -219,8 +241,9 @@ class EnterpriseService:
|
|||||||
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||||
if not app_id:
|
if not app_id:
|
||||||
raise ValueError("app_id must be provided.")
|
raise ValueError("app_id must be provided.")
|
||||||
if access_mode not in ["public", "private", "private_all"]:
|
allowed = {WebAppAccessMode.PUBLIC, WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||||
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
if access_mode not in allowed:
|
||||||
|
raise ValueError(f"access_mode must be one of: {', '.join(m.value for m in allowed)}")
|
||||||
|
|
||||||
data = {"appId": app_id, "accessMode": access_mode}
|
data = {"appId": app_id, "accessMode": access_mode}
|
||||||
|
|
||||||
@ -236,6 +259,32 @@ class EnterpriseService:
|
|||||||
params = {"appId": app_id}
|
params = {"appId": app_id}
|
||||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_externally_accessible_apps(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
page: int,
|
||||||
|
limit: int,
|
||||||
|
mode: str | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.
|
||||||
|
|
||||||
|
Response shape: ``{"data": [{"appId", "tenantId", "mode", "name", "updatedAt"}],
|
||||||
|
"total": int, "hasMore": bool}``.
|
||||||
|
"""
|
||||||
|
body: dict[str, str | int] = {"page": page, "limit": limit}
|
||||||
|
if mode is not None:
|
||||||
|
body["mode"] = mode
|
||||||
|
if name is not None:
|
||||||
|
body["name"] = name
|
||||||
|
return EnterpriseRequest.send_request(
|
||||||
|
"POST",
|
||||||
|
"/webapp/externally-accessible-apps",
|
||||||
|
json=body,
|
||||||
|
timeout=5.0,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||||
|
|||||||
572
api/services/oauth_device_flow.py
Normal file
572
api/services/oauth_device_flow.py
Normal file
@ -0,0 +1,572 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, NotRequired, TypedDict
|
||||||
|
|
||||||
|
from sqlalchemy import and_, func, select, update
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.orm import Session, scoped_session
|
||||||
|
|
||||||
|
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
|
||||||
|
from models.oauth import OAuthAccessToken
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Redis state machine — device_code + user_code ephemeral state
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
_DEVICE_CODE_KEY_PREFIX = "device_code:"
|
||||||
|
_USER_CODE_KEY_PREFIX = "user_code:"
|
||||||
|
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
|
||||||
|
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
|
||||||
|
|
||||||
|
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
|
||||||
|
# not both observe APPROVED — only the winner gets the plaintext token,
|
||||||
|
# the loser sees nil and the caller maps that to expired_token.
|
||||||
|
_CONSUME_ON_POLL_LUA = """
|
||||||
|
local raw = redis.call('GET', KEYS[1])
|
||||||
|
if not raw then return nil end
|
||||||
|
local ok, decoded = pcall(cjson.decode, raw)
|
||||||
|
if not ok then return nil end
|
||||||
|
if decoded.status == 'pending' then return nil end
|
||||||
|
if decoded.user_code then
|
||||||
|
redis.call('DEL', ARGV[1] .. decoded.user_code)
|
||||||
|
end
|
||||||
|
redis.call('DEL', KEYS[1])
|
||||||
|
return raw
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
|
||||||
|
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
|
||||||
|
|
||||||
|
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
|
||||||
|
USER_CODE_SEGMENT_LEN = 4
|
||||||
|
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
|
||||||
|
|
||||||
|
DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceFlowStatus(StrEnum):
|
||||||
|
PENDING = "pending"
|
||||||
|
APPROVED = "approved"
|
||||||
|
DENIED = "denied"
|
||||||
|
|
||||||
|
|
||||||
|
class SlowDownDecision(StrEnum):
|
||||||
|
OK = "ok"
|
||||||
|
SLOW_DOWN = "slow_down"
|
||||||
|
|
||||||
|
|
||||||
|
class PollPayload(TypedDict):
|
||||||
|
"""Body served by the unauthenticated poll endpoint
|
||||||
|
(`POST /openapi/v1/oauth/device/token`) once approve has run.
|
||||||
|
|
||||||
|
A single shape across both branches so the CLI/SPA can parse one
|
||||||
|
contract:
|
||||||
|
|
||||||
|
- ``account`` branch (built in `controllers.openapi.oauth_device.
|
||||||
|
_build_account_poll_payload`) populates ``account`` + ``workspaces``
|
||||||
|
+ ``default_workspace_id`` and omits the SSO-only fields.
|
||||||
|
- ``external_sso`` branch (built in
|
||||||
|
`controllers.openapi.oauth_device_sso.approve_external`) populates
|
||||||
|
``subject_email`` + ``subject_issuer`` and zero-fills the
|
||||||
|
account/workspace fields (``None`` / ``[]``).
|
||||||
|
|
||||||
|
Pre-rendering here means the unauthenticated poll handler doesn't
|
||||||
|
re-query accounts/tenants for authz data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
token: str
|
||||||
|
expires_at: str
|
||||||
|
subject_type: SubjectType
|
||||||
|
account: dict[str, object] | None
|
||||||
|
workspaces: list[dict[str, object]]
|
||||||
|
default_workspace_id: str | None
|
||||||
|
token_id: str
|
||||||
|
subject_email: NotRequired[str]
|
||||||
|
subject_issuer: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceFlowState:
|
||||||
|
"""``minted_token`` is plaintext between approve and the next poll;
|
||||||
|
DEL'd after the poll reads it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_code: str
|
||||||
|
client_id: str
|
||||||
|
device_label: str
|
||||||
|
status: DeviceFlowStatus
|
||||||
|
subject_email: str | None = None
|
||||||
|
account_id: str | None = None
|
||||||
|
subject_issuer: str | None = None
|
||||||
|
minted_token: str | None = None
|
||||||
|
token_id: str | None = None
|
||||||
|
created_at: str = ""
|
||||||
|
created_ip: str = ""
|
||||||
|
last_poll_at: str = ""
|
||||||
|
poll_payload: PollPayload | None = field(default=None)
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
return json.dumps(asdict(self))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, raw: str) -> DeviceFlowState:
|
||||||
|
data = json.loads(raw)
|
||||||
|
if "status" in data:
|
||||||
|
data["status"] = DeviceFlowStatus(data["status"])
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
def _random_device_code() -> str:
|
||||||
|
return "dc_" + secrets.token_urlsafe(24)
|
||||||
|
|
||||||
|
|
||||||
|
def _random_user_code_segment() -> str:
|
||||||
|
return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
|
||||||
|
|
||||||
|
|
||||||
|
def _random_user_code() -> str:
|
||||||
|
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
|
||||||
|
|
||||||
|
|
||||||
|
class StateNotFoundError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTransitionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UserCodeExhaustedError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceFlowRedis:
|
||||||
|
def __init__(self, redis_client) -> None:
|
||||||
|
self._redis = redis_client
|
||||||
|
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
|
||||||
|
|
||||||
|
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
|
||||||
|
device_code = _random_device_code()
|
||||||
|
user_code = self._claim_user_code(device_code)
|
||||||
|
state = DeviceFlowState(
|
||||||
|
user_code=user_code,
|
||||||
|
client_id=client_id,
|
||||||
|
device_label=device_label,
|
||||||
|
status=DeviceFlowStatus.PENDING,
|
||||||
|
created_at=datetime.now(UTC).isoformat(),
|
||||||
|
created_ip=created_ip,
|
||||||
|
)
|
||||||
|
self._redis.setex(
|
||||||
|
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||||
|
DEVICE_FLOW_TTL_SECONDS,
|
||||||
|
state.to_json(),
|
||||||
|
)
|
||||||
|
return device_code, user_code, DEVICE_FLOW_TTL_SECONDS
|
||||||
|
|
||||||
|
def _claim_user_code(self, device_code: str) -> str:
|
||||||
|
for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
|
||||||
|
user_code = _random_user_code()
|
||||||
|
key = USER_CODE_KEY_FMT.format(code=user_code)
|
||||||
|
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
|
||||||
|
if ok:
|
||||||
|
return user_code
|
||||||
|
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
|
||||||
|
|
||||||
|
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
|
||||||
|
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
|
||||||
|
if not raw_dc:
|
||||||
|
return None
|
||||||
|
device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
|
||||||
|
state = self._load_state(device_code)
|
||||||
|
if state is None:
|
||||||
|
return None
|
||||||
|
return device_code, state
|
||||||
|
|
||||||
|
def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
|
||||||
|
return self._load_state(device_code)
|
||||||
|
|
||||||
|
def _load_state(self, device_code: str) -> DeviceFlowState | None:
|
||||||
|
raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||||
|
try:
|
||||||
|
return DeviceFlowState.from_json(text_)
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
logger.exception("device_flow: corrupt state for %s", device_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def approve(
|
||||||
|
self,
|
||||||
|
device_code: str,
|
||||||
|
subject_email: str,
|
||||||
|
account_id: str | None,
|
||||||
|
minted_token: str,
|
||||||
|
token_id: str,
|
||||||
|
subject_issuer: str | None = None,
|
||||||
|
poll_payload: PollPayload | None = None,
|
||||||
|
) -> None:
|
||||||
|
state = self._load_state(device_code)
|
||||||
|
if state is None:
|
||||||
|
raise StateNotFoundError(device_code)
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
raise InvalidTransitionError(f"cannot approve {state.status}")
|
||||||
|
|
||||||
|
state.status = DeviceFlowStatus.APPROVED
|
||||||
|
state.subject_email = subject_email
|
||||||
|
state.account_id = account_id
|
||||||
|
state.subject_issuer = subject_issuer
|
||||||
|
state.minted_token = minted_token
|
||||||
|
state.token_id = token_id
|
||||||
|
state.poll_payload = poll_payload
|
||||||
|
|
||||||
|
new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
|
||||||
|
self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())
|
||||||
|
|
||||||
|
def deny(self, device_code: str) -> None:
|
||||||
|
state = self._load_state(device_code)
|
||||||
|
if state is None:
|
||||||
|
raise StateNotFoundError(device_code)
|
||||||
|
if state.status is not DeviceFlowStatus.PENDING:
|
||||||
|
raise InvalidTransitionError(f"cannot deny {state.status}")
|
||||||
|
state.status = DeviceFlowStatus.DENIED
|
||||||
|
self._redis.setex(
|
||||||
|
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||||
|
self._remaining_ttl(device_code, floor=1),
|
||||||
|
state.to_json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
|
||||||
|
"""Race-safe via Lua EVAL: GET + status-check + DEL execute in a
|
||||||
|
single Redis transaction so only one of N concurrent pollers
|
||||||
|
observes the APPROVED state. Losers get None, mapped to
|
||||||
|
expired_token by the caller.
|
||||||
|
"""
|
||||||
|
raw = self._consume_on_poll_script(
|
||||||
|
keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
|
||||||
|
args=[_USER_CODE_KEY_PREFIX],
|
||||||
|
)
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||||
|
try:
|
||||||
|
return DeviceFlowState.from_json(text_)
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
logger.exception("device_flow: corrupt state on consume %s", device_code)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
|
||||||
|
now = time.time()
|
||||||
|
key = f"device_code:{device_code}:last_poll"
|
||||||
|
prev_raw = self._redis.get(key)
|
||||||
|
self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
|
||||||
|
if prev_raw is None:
|
||||||
|
return SlowDownDecision.OK
|
||||||
|
prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
|
||||||
|
try:
|
||||||
|
prev = float(prev_s)
|
||||||
|
except ValueError:
|
||||||
|
return SlowDownDecision.OK
|
||||||
|
if now - prev < interval_seconds:
|
||||||
|
return SlowDownDecision.SLOW_DOWN
|
||||||
|
return SlowDownDecision.OK
|
||||||
|
|
||||||
|
def _remaining_ttl(self, device_code: str, floor: int) -> int:
|
||||||
|
"""``max(remaining, floor)`` — guarantees the CLI has at least
|
||||||
|
``floor`` seconds to poll after a near-expiry approve.
|
||||||
|
"""
|
||||||
|
ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||||
|
if ttl is None or ttl < 0:
|
||||||
|
return floor
|
||||||
|
return max(int(ttl), floor)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Token mint — generate + upsert
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
OAUTH_BODY_BYTES = 32 # ~256 bits entropy
|
||||||
|
PREFIX_OAUTH_ACCOUNT = "dfoa_"
|
||||||
|
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
|
||||||
|
|
||||||
|
# Sentinel issuer for account-flow rows. Postgres' default partial unique
|
||||||
|
# index treats NULLs as distinct, which would let two live `dfoa_` rows
|
||||||
|
# share (email, client, device) and break rotate-in-place. Storing a
|
||||||
|
# non-empty literal makes the composite key collide as intended.
|
||||||
|
ACCOUNT_ISSUER_SENTINEL = "dify:account"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class MintResult:
|
||||||
|
"""Plaintext token surfaces to the caller once."""
|
||||||
|
|
||||||
|
token: str
|
||||||
|
token_id: uuid.UUID
|
||||||
|
expires_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class UpsertOutcome:
|
||||||
|
token_id: uuid.UUID
|
||||||
|
rotated: bool
|
||||||
|
old_hash: str | None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token(prefix: str) -> str:
|
||||||
|
return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
|
||||||
|
|
||||||
|
|
||||||
|
def sha256_hex(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def mint_oauth_token(
|
||||||
|
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
|
||||||
|
# the wrapper proxies the same execute/commit surface.
|
||||||
|
session: Session | scoped_session,
|
||||||
|
redis_client,
|
||||||
|
*,
|
||||||
|
subject_email: str,
|
||||||
|
subject_issuer: str | None,
|
||||||
|
account_id: str | None,
|
||||||
|
client_id: str,
|
||||||
|
device_label: str,
|
||||||
|
prefix: str,
|
||||||
|
ttl_days: int,
|
||||||
|
) -> MintResult:
|
||||||
|
"""Live row rotates in place via partial unique index
|
||||||
|
``uq_oauth_active_per_device``; hard-expired rows are excluded by the
|
||||||
|
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
|
||||||
|
deleted so stale AuthContext drops immediately.
|
||||||
|
"""
|
||||||
|
if prefix == PREFIX_OAUTH_ACCOUNT:
|
||||||
|
# Account flow always writes the sentinel — caller may pass None
|
||||||
|
# (for clarity) or the sentinel itself; nothing else is valid.
|
||||||
|
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
|
||||||
|
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
|
||||||
|
subject_issuer = ACCOUNT_ISSUER_SENTINEL
|
||||||
|
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
|
||||||
|
# Defense in depth: enterprise canonicalises + rejects empty,
|
||||||
|
# but a regression there must not yield a NULL composite key here.
|
||||||
|
if not subject_issuer or not subject_issuer.strip():
|
||||||
|
raise ValueError("external-SSO token requires non-empty subject_issuer")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown oauth prefix: {prefix!r}")
|
||||||
|
|
||||||
|
token = generate_token(prefix)
|
||||||
|
new_hash = sha256_hex(token)
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
|
||||||
|
|
||||||
|
outcome = _upsert(
|
||||||
|
session,
|
||||||
|
subject_email=subject_email,
|
||||||
|
subject_issuer=subject_issuer,
|
||||||
|
account_id=account_id,
|
||||||
|
client_id=client_id,
|
||||||
|
device_label=device_label,
|
||||||
|
prefix=prefix,
|
||||||
|
new_hash=new_hash,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
if outcome.rotated and outcome.old_hash:
|
||||||
|
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
|
||||||
|
|
||||||
|
return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
*,
|
||||||
|
subject_email: str,
|
||||||
|
subject_issuer: str | None,
|
||||||
|
account_id: str | None,
|
||||||
|
client_id: str,
|
||||||
|
device_label: str,
|
||||||
|
prefix: str,
|
||||||
|
new_hash: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
) -> UpsertOutcome:
|
||||||
|
# Snapshot prior live row's hash for Redis invalidation post-rotate.
|
||||||
|
# subject_issuer is always non-null here (account flow uses sentinel,
|
||||||
|
# external-SSO is validated upstream), so equality matches the index.
|
||||||
|
prior = session.execute(
|
||||||
|
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
|
||||||
|
.where(
|
||||||
|
OAuthAccessToken.subject_email == subject_email,
|
||||||
|
OAuthAccessToken.subject_issuer == subject_issuer,
|
||||||
|
OAuthAccessToken.client_id == client_id,
|
||||||
|
OAuthAccessToken.device_label == device_label,
|
||||||
|
OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
).first()
|
||||||
|
old_hash = prior.token_hash if prior else None
|
||||||
|
|
||||||
|
insert_stmt = pg_insert(OAuthAccessToken).values(
|
||||||
|
subject_email=subject_email,
|
||||||
|
subject_issuer=subject_issuer,
|
||||||
|
account_id=account_id,
|
||||||
|
client_id=client_id,
|
||||||
|
device_label=device_label,
|
||||||
|
prefix=prefix,
|
||||||
|
token_hash=new_hash,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
upsert_stmt = insert_stmt.on_conflict_do_update(
|
||||||
|
index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||||
|
index_where=OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
set_={
|
||||||
|
"token_hash": insert_stmt.excluded.token_hash,
|
||||||
|
"prefix": insert_stmt.excluded.prefix,
|
||||||
|
"account_id": insert_stmt.excluded.account_id,
|
||||||
|
"expires_at": insert_stmt.excluded.expires_at,
|
||||||
|
"created_at": func.now(),
|
||||||
|
"last_used_at": None,
|
||||||
|
},
|
||||||
|
).returning(OAuthAccessToken.id)
|
||||||
|
row = session.execute(upsert_stmt).first()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
raise RuntimeError("oauth_token upsert returned no row")
|
||||||
|
token_id = uuid.UUID(str(row.id))
|
||||||
|
return UpsertOutcome(
|
||||||
|
token_id=token_id,
|
||||||
|
rotated=prior is not None,
|
||||||
|
old_hash=old_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# TTL policy — days new OAuth tokens live
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_OAUTH_TTL_DAYS = 14
|
||||||
|
MIN_TTL_DAYS = 1
|
||||||
|
MAX_TTL_DAYS = 365
|
||||||
|
|
||||||
|
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
|
||||||
|
|
||||||
|
|
||||||
|
def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||||
|
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
|
||||||
|
is deferred; when it lands it wins over the env (Redis-cached 60s).
|
||||||
|
"""
|
||||||
|
_ = tenant_id
|
||||||
|
|
||||||
|
raw = os.environ.get(_TTL_ENV_VAR)
|
||||||
|
if raw is None:
|
||||||
|
return DEFAULT_OAUTH_TTL_DAYS
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(
|
||||||
|
"%s=%r is not an int; falling back to %d",
|
||||||
|
_TTL_ENV_VAR,
|
||||||
|
raw,
|
||||||
|
DEFAULT_OAUTH_TTL_DAYS,
|
||||||
|
)
|
||||||
|
return DEFAULT_OAUTH_TTL_DAYS
|
||||||
|
if value < MIN_TTL_DAYS:
|
||||||
|
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
|
||||||
|
return MIN_TTL_DAYS
|
||||||
|
if value > MAX_TTL_DAYS:
|
||||||
|
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
|
||||||
|
return MAX_TTL_DAYS
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def subject_match_clauses(ctx: AuthContext) -> tuple[Any, ...]:
|
||||||
|
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||||
|
return (OAuthAccessToken.account_id == str(ctx.account_id),)
|
||||||
|
return (
|
||||||
|
OAuthAccessToken.subject_email == ctx.subject_email,
|
||||||
|
OAuthAccessToken.subject_issuer == ctx.subject_issuer,
|
||||||
|
OAuthAccessToken.account_id.is_(None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_active_sessions(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
ctx: AuthContext,
|
||||||
|
now: datetime,
|
||||||
|
) -> list[OAuthAccessToken]:
|
||||||
|
return list(
|
||||||
|
session.execute(
|
||||||
|
select(OAuthAccessToken)
|
||||||
|
.where(
|
||||||
|
and_(
|
||||||
|
*subject_match_clauses(ctx),
|
||||||
|
OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
OAuthAccessToken.token_hash.is_not(None),
|
||||||
|
OAuthAccessToken.expires_at > now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.order_by(OAuthAccessToken.created_at.desc())
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def token_belongs_to_subject(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
token_id: str,
|
||||||
|
ctx: AuthContext,
|
||||||
|
) -> bool:
|
||||||
|
row = session.execute(
|
||||||
|
select(OAuthAccessToken.id).where(
|
||||||
|
and_(
|
||||||
|
OAuthAccessToken.id == token_id,
|
||||||
|
*subject_match_clauses(ctx),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_oauth_token(
|
||||||
|
session: Session | scoped_session,
|
||||||
|
redis_client: Any,
|
||||||
|
token_id: str,
|
||||||
|
) -> None:
|
||||||
|
row = (
|
||||||
|
session.query(OAuthAccessToken.token_hash)
|
||||||
|
.filter(
|
||||||
|
OAuthAccessToken.id == token_id,
|
||||||
|
OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
pre_revoke_hash = row[0] if row else None
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(OAuthAccessToken)
|
||||||
|
.where(
|
||||||
|
OAuthAccessToken.id == token_id,
|
||||||
|
OAuthAccessToken.revoked_at.is_(None),
|
||||||
|
)
|
||||||
|
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||||
|
)
|
||||||
|
session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if pre_revoke_hash:
|
||||||
|
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
|
||||||
0
api/services/openapi/__init__.py
Normal file
0
api/services/openapi/__init__.py
Normal file
52
api/services/openapi/license_gate.py
Normal file
52
api/services/openapi/license_gate.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""License gate for the /openapi/v1/permitted-external-apps* surface.
|
||||||
|
|
||||||
|
EE-only. CE deploys (``ENTERPRISE_ENABLED=false``) skip the gate entirely —
|
||||||
|
the EE blueprint chain is what gives CE deploys no callers on this surface
|
||||||
|
in practice, but the explicit short-circuit avoids any test/fixture that
|
||||||
|
flips the surface on without flipping the license.
|
||||||
|
|
||||||
|
Reuses ``FeatureService.get_system_features()`` so the license status
|
||||||
|
travels the same path as the console reads.
|
||||||
|
|
||||||
|
Companion to ``controllers.console.wraps.enterprise_license_required`` —
|
||||||
|
that one is for console (cookie-authed, force-logout 401). This one is
|
||||||
|
for bearer surface (token-authed, 403 ``license_required``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from services.feature_service import FeatureService, LicenseStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VALID_LICENSE_STATUSES: frozenset[LicenseStatus] = frozenset({LicenseStatus.ACTIVE, LicenseStatus.EXPIRING})
|
||||||
|
|
||||||
|
|
||||||
|
def license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""Decorator form. Raises ``Forbidden('license_required')`` when the EE
|
||||||
|
deployment has no valid license. No-op on CE (``ENTERPRISE_ENABLED=false``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
if dify_config.ENTERPRISE_ENABLED and not _is_license_valid():
|
||||||
|
raise Forbidden(description="license_required")
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
|
||||||
|
def _is_license_valid() -> bool:
|
||||||
|
try:
|
||||||
|
features = FeatureService.get_system_features()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("license_gate: FeatureService.get_system_features failed")
|
||||||
|
return False
|
||||||
|
return features.license.status in _VALID_LICENSE_STATUSES
|
||||||
47
api/services/openapi/mint_policy.py
Normal file
47
api/services/openapi/mint_policy.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
"""Hard mint policy.
|
||||||
|
|
||||||
|
``validate_mint_policy`` cross-checks a (subject_type, prefix, scopes)
|
||||||
|
triple a caller intends to mint against ``MINTABLE_PROFILES`` —
|
||||||
|
the single source of truth in ``libs.oauth_bearer``.
|
||||||
|
|
||||||
|
The defense-in-depth value: if a future caller assembles ``prefix`` or
|
||||||
|
``scopes`` from a non-canonical source (env, request body, plug-in
|
||||||
|
contribution), the mismatch fails closed at approve time before any
|
||||||
|
row hits the DB. When the caller reads straight from
|
||||||
|
``MINTABLE_PROFILES``, the check is a structural pin — it confirms the
|
||||||
|
table entry is well-formed and the caller picked the right key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from libs.oauth_bearer import MINTABLE_PROFILES, Scope, SubjectType
|
||||||
|
|
||||||
|
|
||||||
|
class MintPolicyViolation(Exception): # noqa: N818 — spec-defined name, used in BadRequest message
|
||||||
|
"""Raised on a (subject_type, prefix, scopes) mismatch. Callers translate
|
||||||
|
to 400 ``mint_policy_violation``."""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_mint_policy(
|
||||||
|
*,
|
||||||
|
subject_type: SubjectType,
|
||||||
|
prefix: str,
|
||||||
|
scopes: frozenset[Scope],
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``MintPolicyViolation`` when the triple does not match the
|
||||||
|
canonical ``MINTABLE_PROFILES`` entry for ``subject_type``.
|
||||||
|
"""
|
||||||
|
profile = MINTABLE_PROFILES.get(subject_type)
|
||||||
|
if profile is None:
|
||||||
|
raise MintPolicyViolation(f"mint_policy_violation: unknown subject_type={subject_type!r}")
|
||||||
|
|
||||||
|
drift = []
|
||||||
|
if profile.prefix != prefix:
|
||||||
|
drift.append(f"prefix got={prefix!r} expected={profile.prefix!r}")
|
||||||
|
if frozenset(scopes) != profile.scopes:
|
||||||
|
got = sorted(s.value for s in scopes)
|
||||||
|
want = sorted(s.value for s in profile.scopes)
|
||||||
|
drift.append(f"scopes got={got} expected={want}")
|
||||||
|
|
||||||
|
if drift:
|
||||||
|
raise MintPolicyViolation(f"mint_policy_violation: subject_type={subject_type.value} — " + "; ".join(drift))
|
||||||
32
api/services/openapi/visibility.py
Normal file
32
api/services/openapi/visibility.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
"""Single-source visibility filter for the /openapi/v1/* surface.
|
||||||
|
|
||||||
|
Keep every openapi-surface app query routed through ``_apply_openapi_gate``;
|
||||||
|
retiring or replacing the gate then becomes a one-line change here.
|
||||||
|
|
||||||
|
The Service API (/v1/* app-key surface) does NOT use this helper — that
|
||||||
|
surface has its own per-request guard (``service_api_disabled``) wired
|
||||||
|
into the legacy ``validate_app_token`` decorator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from models.model import App
|
||||||
|
|
||||||
|
|
||||||
|
def apply_openapi_gate(query: Any) -> Any:
|
||||||
|
"""Filter a SQLAlchemy Select/Query to apps visible on /openapi/v1/*.
|
||||||
|
|
||||||
|
Works with both legacy ``Query.filter`` and 2.0-style ``Select.filter``
|
||||||
|
(alias of ``.where``).
|
||||||
|
"""
|
||||||
|
return query.filter(App.enable_api.is_(True))
|
||||||
|
|
||||||
|
|
||||||
|
def is_openapi_visible(app: App) -> bool:
|
||||||
|
"""Per-row counterpart for code paths that fetch an App by primary key
|
||||||
|
(``session.get`` / ``session.scalar``) and need the same visibility check
|
||||||
|
the query gate would have applied.
|
||||||
|
"""
|
||||||
|
return bool(app.enable_api)
|
||||||
@ -22,7 +22,6 @@ from core.helper import marketplace
|
|||||||
from core.plugin.entities.plugin import PluginInstallationSource
|
from core.plugin.entities.plugin import PluginInstallationSource
|
||||||
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Tenant
|
from models.account import Tenant
|
||||||
@ -30,6 +29,7 @@ from models.model import App, AppMode, AppModelConfig
|
|||||||
from models.provider_ids import ModelProviderID, ToolProviderID
|
from models.provider_ids import ModelProviderID, ToolProviderID
|
||||||
from models.tools import BuiltinToolProvider
|
from models.tools import BuiltinToolProvider
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -389,19 +389,17 @@ class PluginMigration:
|
|||||||
for plugin_id in batch_plugin_ids
|
for plugin_id in batch_plugin_ids
|
||||||
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
|
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
|
||||||
]
|
]
|
||||||
if batch_plugin_identifiers:
|
manager.install_from_identifiers(
|
||||||
manager.install_from_identifiers(
|
tenant_id,
|
||||||
tenant_id,
|
batch_plugin_identifiers,
|
||||||
batch_plugin_identifiers,
|
PluginInstallationSource.Marketplace,
|
||||||
PluginInstallationSource.Marketplace,
|
metas=[
|
||||||
metas=[
|
{
|
||||||
{
|
"plugin_unique_identifier": identifier,
|
||||||
"plugin_unique_identifier": identifier,
|
}
|
||||||
}
|
for identifier in batch_plugin_identifiers
|
||||||
for identifier in batch_plugin_identifiers
|
],
|
||||||
],
|
)
|
||||||
)
|
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
|
|
||||||
with open(extracted_plugins) as f:
|
with open(extracted_plugins) as f:
|
||||||
"""
|
"""
|
||||||
@ -597,7 +595,6 @@ class PluginMigration:
|
|||||||
for identifier in batch_plugin_identifiers
|
for identifier in batch_plugin_identifiers
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# add to failed
|
# add to failed
|
||||||
failed.extend(batch_plugin_identifiers)
|
failed.extend(batch_plugin_identifiers)
|
||||||
@ -612,7 +609,6 @@ class PluginMigration:
|
|||||||
while not done:
|
while not done:
|
||||||
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
status = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||||
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
for plugin in status.plugins:
|
for plugin in status.plugins:
|
||||||
if plugin.status == PluginInstallTaskStatus.Success:
|
if plugin.status == PluginInstallTaskStatus.Success:
|
||||||
success.append(reverse_map[plugin.plugin_unique_identifier])
|
success.append(reverse_map[plugin.plugin_unique_identifier])
|
||||||
|
|||||||
@ -1,17 +1,8 @@
|
|||||||
"""Core plugin service and tenant-scoped plugin metadata cache ownership.
|
|
||||||
|
|
||||||
This module owns plugin daemon management calls that are shared by API services
|
|
||||||
and core runtimes. Plugin model provider discovery is cached here, alongside
|
|
||||||
plugin install, uninstall, and upgrade invalidation, so all cache mutations for
|
|
||||||
plugin-owned provider metadata stay tenant-scoped and in one place.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter, ValidationError
|
from pydantic import BaseModel
|
||||||
from redis import RedisError
|
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy import delete, select, update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
@ -31,20 +22,16 @@ from core.plugin.entities.plugin import (
|
|||||||
from core.plugin.entities.plugin_daemon import (
|
from core.plugin.entities.plugin_daemon import (
|
||||||
PluginDecodeResponse,
|
PluginDecodeResponse,
|
||||||
PluginInstallTask,
|
PluginInstallTask,
|
||||||
PluginInstallTaskStatus,
|
|
||||||
PluginListResponse,
|
PluginListResponse,
|
||||||
PluginModelProviderEntity,
|
|
||||||
PluginVerification,
|
PluginVerification,
|
||||||
)
|
)
|
||||||
from core.plugin.impl.asset import PluginAssetManager
|
from core.plugin.impl.asset import PluginAssetManager
|
||||||
from core.plugin.impl.debugging import PluginDebuggingClient
|
from core.plugin.impl.debugging import PluginDebuggingClient
|
||||||
from core.plugin.impl.model import PluginModelClient
|
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from graphon.model_runtime.entities.provider_entities import ProviderEntity
|
|
||||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||||
from models.provider_ids import GenericProviderID, ModelProviderID
|
from models.provider_ids import GenericProviderID
|
||||||
from services.enterprise.plugin_manager_service import (
|
from services.enterprise.plugin_manager_service import (
|
||||||
PluginManagerService,
|
PluginManagerService,
|
||||||
PreUninstallPluginRequest,
|
PreUninstallPluginRequest,
|
||||||
@ -53,7 +40,6 @@ from services.errors.plugin import PluginInstallationForbiddenError
|
|||||||
from services.feature_service import FeatureService, PluginInstallationScope
|
from services.feature_service import FeatureService, PluginInstallationScope
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity])
|
|
||||||
|
|
||||||
|
|
||||||
class PluginService:
|
class PluginService:
|
||||||
@ -67,102 +53,6 @@ class PluginService:
|
|||||||
|
|
||||||
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
|
||||||
REDIS_TTL = 60 * 5 # 5 minutes
|
REDIS_TTL = 60 * 5 # 5 minutes
|
||||||
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
|
|
||||||
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
|
|
||||||
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
|
|
||||||
"""
|
|
||||||
Expose a bare provider alias only for the canonical provider mapping.
|
|
||||||
|
|
||||||
Multiple plugins can publish the same short provider slug. If every
|
|
||||||
provider entity keeps that slug in ``provider_name``, callers that still
|
|
||||||
resolve by short name become order-dependent. Restrict the alias to the
|
|
||||||
provider selected by ``ModelProviderID`` so legacy short-name lookups
|
|
||||||
remain deterministic while the runtime surface stays canonical.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
canonical_provider_id = ModelProviderID(provider.provider)
|
|
||||||
except ValueError:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if canonical_provider_id.plugin_id != provider.plugin_id:
|
|
||||||
return ""
|
|
||||||
if canonical_provider_id.provider_name != provider.provider:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return provider.provider
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity:
|
|
||||||
declaration = provider.declaration.model_copy(deep=True)
|
|
||||||
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
|
|
||||||
declaration.provider_name = cls._get_provider_short_name_alias(provider)
|
|
||||||
return declaration
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
|
|
||||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
|
||||||
try:
|
|
||||||
cached_providers = redis_client.get(cache_key)
|
|
||||||
except (RedisError, RuntimeError):
|
|
||||||
logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not cached_providers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
return tuple(_provider_entities_adapter.validate_json(cached_providers))
|
|
||||||
except (TypeError, ValueError, ValidationError):
|
|
||||||
logger.warning(
|
|
||||||
"Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True
|
|
||||||
)
|
|
||||||
cls.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
|
|
||||||
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
|
|
||||||
try:
|
|
||||||
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
|
|
||||||
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
|
|
||||||
except (RedisError, RuntimeError):
|
|
||||||
logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
|
|
||||||
"""Delete the tenant-scoped plugin model provider list cache."""
|
|
||||||
try:
|
|
||||||
redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id))
|
|
||||||
except (RedisError, RuntimeError):
|
|
||||||
logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fetch_plugin_model_providers(
|
|
||||||
cls, *, tenant_id: str, client: PluginModelClient | None = None
|
|
||||||
) -> Sequence[ProviderEntity]:
|
|
||||||
"""
|
|
||||||
Fetch plugin model providers through the tenant-scoped plugin cache.
|
|
||||||
|
|
||||||
Plugin daemon provider discovery and plugin lifecycle cache invalidation
|
|
||||||
are intentionally owned by this service so tenant isolation and cache
|
|
||||||
expiry are handled in one place.
|
|
||||||
"""
|
|
||||||
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
|
|
||||||
if cached_providers is not None:
|
|
||||||
return cached_providers
|
|
||||||
|
|
||||||
model_client = client or PluginModelClient()
|
|
||||||
providers = tuple(
|
|
||||||
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
|
|
||||||
)
|
|
||||||
cls._store_cached_plugin_model_providers(tenant_id, providers)
|
|
||||||
return providers
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
|
||||||
@ -358,18 +248,12 @@ class PluginService:
|
|||||||
Fetch plugin installation tasks
|
Fetch plugin installation tasks
|
||||||
"""
|
"""
|
||||||
manager = PluginInstaller()
|
manager = PluginInstaller()
|
||||||
tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
|
||||||
if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks):
|
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return tasks
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||||
manager = PluginInstaller()
|
manager = PluginInstaller()
|
||||||
task = manager.fetch_plugin_installation_task(tenant_id, task_id)
|
return manager.fetch_plugin_installation_task(tenant_id, task_id)
|
||||||
if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES:
|
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return task
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
def delete_install_task(tenant_id: str, task_id: str) -> bool:
|
||||||
@ -431,7 +315,7 @@ class PluginService:
|
|||||||
# check if the plugin is available to install
|
# check if the plugin is available to install
|
||||||
PluginService._check_plugin_installation_scope(response.verification)
|
PluginService._check_plugin_installation_scope(response.verification)
|
||||||
|
|
||||||
result = manager.upgrade_plugin(
|
return manager.upgrade_plugin(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
original_plugin_unique_identifier,
|
original_plugin_unique_identifier,
|
||||||
new_plugin_unique_identifier,
|
new_plugin_unique_identifier,
|
||||||
@ -440,8 +324,6 @@ class PluginService:
|
|||||||
"plugin_unique_identifier": new_plugin_unique_identifier,
|
"plugin_unique_identifier": new_plugin_unique_identifier,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def upgrade_plugin_with_github(
|
def upgrade_plugin_with_github(
|
||||||
@ -457,7 +339,7 @@ class PluginService:
|
|||||||
"""
|
"""
|
||||||
PluginService._check_marketplace_only_permission()
|
PluginService._check_marketplace_only_permission()
|
||||||
manager = PluginInstaller()
|
manager = PluginInstaller()
|
||||||
result = manager.upgrade_plugin(
|
return manager.upgrade_plugin(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
original_plugin_unique_identifier,
|
original_plugin_unique_identifier,
|
||||||
new_plugin_unique_identifier,
|
new_plugin_unique_identifier,
|
||||||
@ -468,8 +350,6 @@ class PluginService:
|
|||||||
"package": package,
|
"package": package,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
|
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
|
||||||
@ -535,14 +415,12 @@ class PluginService:
|
|||||||
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||||
PluginService._check_plugin_installation_scope(resp.verification)
|
PluginService._check_plugin_installation_scope(resp.verification)
|
||||||
|
|
||||||
result = manager.install_from_identifiers(
|
return manager.install_from_identifiers(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
plugin_unique_identifiers,
|
plugin_unique_identifiers,
|
||||||
PluginInstallationSource.Package,
|
PluginInstallationSource.Package,
|
||||||
[{}],
|
[{}],
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
|
||||||
@ -556,7 +434,7 @@ class PluginService:
|
|||||||
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
|
||||||
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
|
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
|
||||||
|
|
||||||
result = manager.install_from_identifiers(
|
return manager.install_from_identifiers(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
[plugin_unique_identifier],
|
[plugin_unique_identifier],
|
||||||
PluginInstallationSource.Github,
|
PluginInstallationSource.Github,
|
||||||
@ -568,8 +446,6 @@ class PluginService:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||||
@ -637,14 +513,12 @@ class PluginService:
|
|||||||
actual_plugin_unique_identifiers.append(response.unique_identifier)
|
actual_plugin_unique_identifiers.append(response.unique_identifier)
|
||||||
metas.append({"plugin_unique_identifier": response.unique_identifier})
|
metas.append({"plugin_unique_identifier": response.unique_identifier})
|
||||||
|
|
||||||
result = manager.install_from_identifiers(
|
return manager.install_from_identifiers(
|
||||||
tenant_id,
|
tenant_id,
|
||||||
actual_plugin_unique_identifiers,
|
actual_plugin_unique_identifiers,
|
||||||
PluginInstallationSource.Marketplace,
|
PluginInstallationSource.Marketplace,
|
||||||
metas,
|
metas,
|
||||||
)
|
)
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
|
||||||
@ -655,10 +529,7 @@ class PluginService:
|
|||||||
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
|
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
|
||||||
|
|
||||||
if not plugin:
|
if not plugin:
|
||||||
result = manager.uninstall(tenant_id, plugin_installation_id)
|
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||||
if result:
|
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
if dify_config.ENTERPRISE_ENABLED:
|
if dify_config.ENTERPRISE_ENABLED:
|
||||||
PluginManagerService.try_pre_uninstall_plugin(
|
PluginManagerService.try_pre_uninstall_plugin(
|
||||||
@ -688,39 +559,37 @@ class PluginService:
|
|||||||
|
|
||||||
if not credential_ids:
|
if not credential_ids:
|
||||||
logger.info("No credentials found for plugin: %s", plugin_id)
|
logger.info("No credentials found for plugin: %s", plugin_id)
|
||||||
else:
|
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||||
provider_ids = session.scalars(
|
|
||||||
select(Provider.id).where(
|
|
||||||
Provider.tenant_id == tenant_id,
|
|
||||||
Provider.provider_name.like(f"{plugin_id}/%"),
|
|
||||||
Provider.credential_id.in_(credential_ids),
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
|
provider_ids = session.scalars(
|
||||||
|
select(Provider.id).where(
|
||||||
for provider_id in provider_ids:
|
Provider.tenant_id == tenant_id,
|
||||||
ProviderCredentialsCache(
|
Provider.provider_name.like(f"{plugin_id}/%"),
|
||||||
tenant_id=tenant_id,
|
Provider.credential_id.in_(credential_ids),
|
||||||
identity_id=provider_id,
|
|
||||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
|
||||||
).delete()
|
|
||||||
|
|
||||||
session.execute(
|
|
||||||
delete(ProviderCredential).where(
|
|
||||||
ProviderCredential.id.in_(credential_ids),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
logger.info(
|
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
|
||||||
"Completed deleting credentials and cleaning provider associations for plugin: %s",
|
|
||||||
plugin_id,
|
for provider_id in provider_ids:
|
||||||
|
ProviderCredentialsCache(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
identity_id=provider_id,
|
||||||
|
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||||
|
).delete()
|
||||||
|
|
||||||
|
session.execute(
|
||||||
|
delete(ProviderCredential).where(
|
||||||
|
ProviderCredential.id.in_(credential_ids),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
result = manager.uninstall(tenant_id, plugin_installation_id)
|
logger.info(
|
||||||
if result:
|
"Completed deleting credentials and cleaning provider associations for plugin: %s",
|
||||||
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
|
plugin_id,
|
||||||
return result
|
)
|
||||||
|
|
||||||
|
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||||
@ -12,7 +12,6 @@ from sqlalchemy import select
|
|||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants import DOCUMENT_EXTENSIONS
|
from constants import DOCUMENT_EXTENSIONS
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -23,6 +22,7 @@ from models.model import UploadFile
|
|||||||
from models.workflow import Workflow, WorkflowType
|
from models.workflow import Workflow, WorkflowType
|
||||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
|
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
|
||||||
from services.plugin.plugin_migration import PluginMigration
|
from services.plugin.plugin_migration import PluginMigration
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from core.helper.name_generator import generate_incremental_name
|
|||||||
from core.helper.position_helper import is_filtered
|
from core.helper.position_helper import is_filtered
|
||||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||||
from core.tools.entities.api_entities import (
|
from core.tools.entities.api_entities import (
|
||||||
@ -32,6 +31,7 @@ from extensions.ext_database import db
|
|||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.provider_ids import ToolProviderID
|
from models.provider_ids import ToolProviderID
|
||||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from configs import dify_config
|
|||||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||||
from core.mcp.types import Tool as MCPTool
|
from core.mcp.types import Tool as MCPTool
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
|
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
@ -28,6 +27,7 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p
|
|||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
|
|||||||
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
|
||||||
from core.plugin.entities.plugin_daemon import CredentialType
|
from core.plugin.entities.plugin_daemon import CredentialType
|
||||||
from core.plugin.impl.oauth import OAuthHandler
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.utils.system_encryption import decrypt_system_params
|
from core.tools.utils.system_encryption import decrypt_system_params
|
||||||
from core.trigger.entities.api_entities import (
|
from core.trigger.entities.api_entities import (
|
||||||
TriggerProviderApiEntity,
|
TriggerProviderApiEntity,
|
||||||
@ -38,6 +37,7 @@ from models.trigger import (
|
|||||||
TriggerSubscription,
|
TriggerSubscription,
|
||||||
WorkflowPluginTrigger,
|
WorkflowPluginTrigger,
|
||||||
)
|
)
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from models import Account, AccountStatus
|
|||||||
from models.model import App, EndUser, Site
|
from models.model import App, EndUser, Site
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.app_service import AppService
|
from services.app_service import AppService
|
||||||
from services.enterprise.enterprise_service import EnterpriseService
|
from services.enterprise.enterprise_service import PERMISSION_CHECK_MODES, EnterpriseService, WebAppAccessMode
|
||||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||||
|
|
||||||
@ -137,12 +137,8 @@ class WebAppAuthService:
|
|||||||
"""
|
"""
|
||||||
Check if the app requires permission check based on its access mode.
|
Check if the app requires permission check based on its access mode.
|
||||||
"""
|
"""
|
||||||
modes_requiring_permission_check = [
|
|
||||||
"private",
|
|
||||||
"private_all",
|
|
||||||
]
|
|
||||||
if access_mode:
|
if access_mode:
|
||||||
return access_mode in modes_requiring_permission_check
|
return access_mode in PERMISSION_CHECK_MODES
|
||||||
|
|
||||||
if not app_code and not app_id:
|
if not app_code and not app_id:
|
||||||
raise ValueError("Either app_code or app_id must be provided.")
|
raise ValueError("Either app_code or app_id must be provided.")
|
||||||
@ -153,7 +149,7 @@ class WebAppAuthService:
|
|||||||
raise ValueError("App ID could not be determined from the provided app_code.")
|
raise ValueError("App ID could not be determined from the provided app_code.")
|
||||||
|
|
||||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||||
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
|
if webapp_settings and webapp_settings.access_mode in PERMISSION_CHECK_MODES:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -166,11 +162,11 @@ class WebAppAuthService:
|
|||||||
raise ValueError("Either app_code or access_mode must be provided.")
|
raise ValueError("Either app_code or access_mode must be provided.")
|
||||||
|
|
||||||
if access_mode:
|
if access_mode:
|
||||||
if access_mode == "public":
|
if access_mode == WebAppAccessMode.PUBLIC:
|
||||||
return WebAppAuthType.PUBLIC
|
return WebAppAuthType.PUBLIC
|
||||||
elif access_mode in ["private", "private_all"]:
|
elif access_mode in PERMISSION_CHECK_MODES:
|
||||||
return WebAppAuthType.INTERNAL
|
return WebAppAuthType.INTERNAL
|
||||||
elif access_mode == "sso_verified":
|
elif access_mode == WebAppAccessMode.SSO_VERIFIED:
|
||||||
return WebAppAuthType.EXTERNAL
|
return WebAppAuthType.EXTERNAL
|
||||||
|
|
||||||
if app_code:
|
if app_code:
|
||||||
|
|||||||
@ -6,11 +6,11 @@ from typing import Any, TypedDict
|
|||||||
from sqlalchemy import and_, func, or_, select
|
from sqlalchemy import and_, func, or_, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from graphon.enums import WorkflowExecutionStatus
|
from graphon.enums import WorkflowExecutionStatus
|
||||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
|
||||||
from models.enums import AppTriggerType, CreatorUserRole
|
from models.enums import AppTriggerType, CreatorUserRole
|
||||||
from models.trigger import WorkflowTriggerLog
|
from models.trigger import WorkflowTriggerLog
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
from services.workflow.entities import TriggerMetadata
|
from services.workflow.entities import TriggerMetadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -162,12 +162,18 @@ class _AppRunner:
|
|||||||
user = self._resolve_user()
|
user = self._resolve_user()
|
||||||
|
|
||||||
with self._setup_flask_context(user):
|
with self._setup_flask_context(user):
|
||||||
response = self._run_app(
|
try:
|
||||||
app=app,
|
response = self._run_app(
|
||||||
workflow=workflow,
|
app=app,
|
||||||
user=user,
|
workflow=workflow,
|
||||||
pause_state_config=pause_config,
|
user=user,
|
||||||
)
|
pause_state_config=pause_config,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if exec_params.streaming:
|
||||||
|
_publish_error_event(exc, exec_params.workflow_run_id, exec_params.app_mode)
|
||||||
|
raise
|
||||||
|
|
||||||
if not exec_params.streaming:
|
if not exec_params.streaming:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -238,6 +244,12 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun
|
|||||||
return session.get(EndUser, workflow_run.created_by)
|
return session.get(EndUser, workflow_run.created_by)
|
||||||
|
|
||||||
|
|
||||||
|
def _publish_error_event(exc: Exception, workflow_run_id: str, app_mode: AppMode) -> None:
|
||||||
|
topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id)
|
||||||
|
payload = json.dumps({"event": "error", "message": str(exc), "status": 500})
|
||||||
|
topic.publish(payload.encode())
|
||||||
|
|
||||||
|
|
||||||
def _publish_streaming_response(
|
def _publish_streaming_response(
|
||||||
response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None],
|
response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None],
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
|
|||||||
@ -9,9 +9,9 @@ from celery import shared_task
|
|||||||
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
||||||
from core.plugin.entities.plugin import PluginInstallationSource
|
from core.plugin.entities.plugin import PluginInstallationSource
|
||||||
from core.plugin.impl.plugin import PluginInstaller
|
from core.plugin.impl.plugin import PluginInstaller
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from models.account import TenantPluginAutoUpgradeStrategy
|
from models.account import TenantPluginAutoUpgradeStrategy
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
"""Shared fixtures for /openapi/v1/* integration tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||||
|
from models.account import AccountStatus
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256(token: str) -> str:
|
||||||
|
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def disable_enterprise(monkeypatch):
|
||||||
|
"""Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
|
||||||
|
EE branch override this with their own monkeypatch in-test."""
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
|
||||||
|
with flask_app.app_context():
|
||||||
|
tenant = Tenant(name="t1", status="normal")
|
||||||
|
account = Account(email="u@example.com", name="u")
|
||||||
|
db.session.add_all([tenant, account])
|
||||||
|
db.session.commit()
|
||||||
|
account.status = AccountStatus.ACTIVE
|
||||||
|
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
|
||||||
|
db.session.add(join)
|
||||||
|
db.session.commit()
|
||||||
|
yield account, tenant, join
|
||||||
|
db.session.delete(join)
|
||||||
|
db.session.delete(account)
|
||||||
|
db.session.delete(tenant)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
|
||||||
|
_, tenant, _ = workspace_account
|
||||||
|
with flask_app.app_context():
|
||||||
|
app = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
|
||||||
|
db.session.add(app)
|
||||||
|
db.session.commit()
|
||||||
|
yield app
|
||||||
|
db.session.delete(app)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mint_token(flask_app: Flask):
|
||||||
|
"""Factory fixture; tracks minted rows and deletes them on teardown so
|
||||||
|
the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
|
||||||
|
minted: list[OAuthAccessToken] = []
|
||||||
|
|
||||||
|
def _mint(
|
||||||
|
token: str,
|
||||||
|
*,
|
||||||
|
account_id: str | None,
|
||||||
|
prefix: str,
|
||||||
|
subject_email: str,
|
||||||
|
subject_issuer: str | None,
|
||||||
|
) -> OAuthAccessToken:
|
||||||
|
with flask_app.app_context():
|
||||||
|
row = OAuthAccessToken(
|
||||||
|
token_hash=_sha256(token),
|
||||||
|
prefix=prefix,
|
||||||
|
account_id=account_id,
|
||||||
|
subject_email=subject_email,
|
||||||
|
subject_issuer=subject_issuer,
|
||||||
|
client_id="difyctl",
|
||||||
|
device_label="test-device",
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
db.session.add(row)
|
||||||
|
db.session.commit()
|
||||||
|
minted.append(row)
|
||||||
|
return row
|
||||||
|
|
||||||
|
yield _mint
|
||||||
|
|
||||||
|
with flask_app.app_context():
|
||||||
|
for row in minted:
|
||||||
|
db.session.delete(db.session.merge(row))
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def account_token(workspace_account, mint_token) -> str:
|
||||||
|
account, _, _ = workspace_account
|
||||||
|
token = "dfoa_" + uuid.uuid4().hex
|
||||||
|
mint_token(
|
||||||
|
token,
|
||||||
|
account_id=account.id,
|
||||||
|
prefix="dfoa_",
|
||||||
|
subject_email=account.email,
|
||||||
|
subject_issuer="dify:account",
|
||||||
|
)
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
|
||||||
|
def _flush():
|
||||||
|
with flask_app.app_context():
|
||||||
|
for k in redis_client.keys("auth:*"):
|
||||||
|
redis_client.delete(k)
|
||||||
|
for k in redis_client.keys("rl:*"):
|
||||||
|
redis_client.delete(k)
|
||||||
|
|
||||||
|
_flush()
|
||||||
|
yield
|
||||||
|
_flush()
|
||||||
238
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
238
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import App
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||||
|
captured["mode"] = app_model.mode
|
||||||
|
captured["args"] = args
|
||||||
|
captured["invoke_from"] = invoke_from
|
||||||
|
return {
|
||||||
|
"event": "message",
|
||||||
|
"task_id": "t",
|
||||||
|
"id": "m",
|
||||||
|
"message_id": "m",
|
||||||
|
"conversation_id": "c",
|
||||||
|
"mode": "chat",
|
||||||
|
"answer": "ok",
|
||||||
|
"created_at": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.get_json()["mode"] == "chat"
|
||||||
|
assert captured["mode"] == "chat"
|
||||||
|
assert captured["invoke_from"] == InvokeFrom.OPENAPI
|
||||||
|
assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_with_mode(flask_app: Flask, workspace_account):
|
||||||
|
"""Factory that creates an App row in the workspace_account tenant with
|
||||||
|
a specified mode. Tracks rows for teardown.
|
||||||
|
"""
|
||||||
|
_, tenant, _ = workspace_account
|
||||||
|
created: list[App] = []
|
||||||
|
|
||||||
|
def _make(mode: str) -> App:
|
||||||
|
with flask_app.app_context():
|
||||||
|
app = App(
|
||||||
|
tenant_id=tenant.id,
|
||||||
|
name=f"a-{mode}",
|
||||||
|
mode=mode,
|
||||||
|
status="normal",
|
||||||
|
enable_site=True,
|
||||||
|
enable_api=True,
|
||||||
|
)
|
||||||
|
db.session.add(app)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(app)
|
||||||
|
db.session.expunge(app)
|
||||||
|
created.append(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
yield _make
|
||||||
|
|
||||||
|
with flask_app.app_context():
|
||||||
|
for app in created:
|
||||||
|
db.session.delete(db.session.merge(app))
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"inputs": {}, "response_mode": "blocking"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
assert b"query_required_for_chat" in res.data
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||||
|
app = app_with_mode("completion")
|
||||||
|
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||||
|
captured["mode"] = app_model.mode
|
||||||
|
captured["args"] = args
|
||||||
|
return {
|
||||||
|
"event": "message",
|
||||||
|
"task_id": "t",
|
||||||
|
"id": "m",
|
||||||
|
"message_id": "m",
|
||||||
|
"mode": "completion",
|
||||||
|
"answer": "ok",
|
||||||
|
"created_at": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app.id}/run",
|
||||||
|
json={"inputs": {}, "response_mode": "blocking"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.get_json()["mode"] == "completion"
|
||||||
|
assert captured["mode"] == "completion"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||||
|
app = app_with_mode("workflow")
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app.id}/run",
|
||||||
|
json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
assert b"query_not_supported_for_workflow" in res.data
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||||
|
app = app_with_mode("workflow")
|
||||||
|
|
||||||
|
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||||
|
return {
|
||||||
|
"workflow_run_id": "wfr",
|
||||||
|
"task_id": "t",
|
||||||
|
"data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app.id}/run",
|
||||||
|
json={"inputs": {}, "response_mode": "blocking"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.get_json()
|
||||||
|
assert body["mode"] == "workflow"
|
||||||
|
assert body["workflow_run_id"] == "wfr"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||||
|
app = app_with_mode("channel")
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app.id}/run",
|
||||||
|
json={"inputs": {}, "response_mode": "blocking"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
assert b"mode_not_runnable" in res.data
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"inputs": {}, "query": "hi"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||||
|
"""Stub the authenticator to return an AuthContext with empty scopes."""
|
||||||
|
from libs import oauth_bearer
|
||||||
|
|
||||||
|
real_authenticate = oauth_bearer.BearerAuthenticator.authenticate
|
||||||
|
|
||||||
|
def _stub_authenticate(self, token: str):
|
||||||
|
ctx = real_authenticate(self, token)
|
||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
|
return replace(ctx, scopes=frozenset())
|
||||||
|
|
||||||
|
monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)
|
||||||
|
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"inputs": {}, "query": "hi"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_with_unknown_app_returns_404(flask_app, account_token):
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{uuid.uuid4()}/run",
|
||||||
|
json={"inputs": {}, "query": "hi"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||||
|
def _stream() -> Generator[str, None, None]:
|
||||||
|
yield 'event: message\ndata: {"x": 1}\n\n'
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"controllers.openapi.app_run.AppGenerateService.generate",
|
||||||
|
staticmethod(lambda **kw: _stream()),
|
||||||
|
)
|
||||||
|
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.headers["Content-Type"].startswith("text/event-stream")
|
||||||
|
assert b"event: message" in res.data
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
|
||||||
|
client = flask_app.test_client()
|
||||||
|
res = client.post(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||||
|
json={"query": "hi"},
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
"""Integration tests for /openapi/v1/apps* read surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from flask.testing import FlaskClient
|
||||||
|
|
||||||
|
from models import App
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
|
||||||
|
resp = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
|
||||||
|
resp = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/parameters",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_info_route_404(test_client, app_in_workspace, account_token):
|
||||||
|
resp = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_returns_merged_shape(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"]["id"] == app_in_workspace.id
|
||||||
|
assert body["info"]["mode"] == "chat"
|
||||||
|
assert isinstance(body["parameters"], dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_full_includes_input_schema(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"] is not None
|
||||||
|
assert body["parameters"] is not None
|
||||||
|
assert body["input_schema"] is not None
|
||||||
|
assert body["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_info_only(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"] is not None
|
||||||
|
assert body["parameters"] is None
|
||||||
|
assert body["input_schema"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_parameters_only(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"] is None
|
||||||
|
assert body["parameters"] is not None
|
||||||
|
assert body["input_schema"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_input_schema_only(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"] is None
|
||||||
|
assert body["parameters"] is None
|
||||||
|
assert body["input_schema"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_combined(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["info"] is not None
|
||||||
|
assert body["parameters"] is None
|
||||||
|
assert body["input_schema"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_unknown_returns_422(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_describe_fields_extra_param_returns_422(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_list_returns_pagination_envelope(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
workspace_account,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
_, tenant, _ = workspace_account
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
assert body["page"] == 1
|
||||||
|
assert body["limit"] == 20
|
||||||
|
assert body["total"] >= 1
|
||||||
|
assert any(d["id"] == app_in_workspace.id for d in body["data"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
|
||||||
|
res = test_client.get("/openapi/v1/apps", headers={"Authorization": f"Bearer {account_token}"})
|
||||||
|
assert res.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_apps_list_tag_no_match_returns_empty_data_not_400(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
workspace_account,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
_, tenant, _ = workspace_account
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json["data"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_account_sessions_returns_envelope(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
account_token: str,
|
||||||
|
):
|
||||||
|
res = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json
|
||||||
|
# canonical envelope shape
|
||||||
|
assert isinstance(body["data"], list)
|
||||||
|
assert "page" in body
|
||||||
|
assert "limit" in body
|
||||||
|
assert "total" in body
|
||||||
|
assert "has_more" in body
|
||||||
|
# the bearer's own minted session must appear
|
||||||
|
assert any(s["prefix"] == "dfoa_" for s in body["data"])
|
||||||
|
# legacy "sessions" key must NOT appear
|
||||||
|
assert "sessions" not in body
|
||||||
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
"""Integration tests for the /openapi/v1 bearer auth surface.
|
||||||
|
|
||||||
|
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
|
||||||
|
acceptance/rejection on app-scoped routes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask.testing import FlaskClient
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models import App, Tenant
|
||||||
|
|
||||||
|
|
||||||
|
def test_info_accepts_account_bearer_with_apps_read_scope(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
) -> None:
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json["id"] == app_in_workspace.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
|
||||||
|
"""A fresh app under a *different* tenant — caller has no membership row."""
|
||||||
|
with flask_app.app_context():
|
||||||
|
other_tenant = Tenant(name="other", status="normal")
|
||||||
|
db.session.add(other_tenant)
|
||||||
|
db.session.commit()
|
||||||
|
app = App(
|
||||||
|
tenant_id=other_tenant.id,
|
||||||
|
name="b",
|
||||||
|
mode="chat",
|
||||||
|
status="normal",
|
||||||
|
enable_site=True,
|
||||||
|
enable_api=True,
|
||||||
|
)
|
||||||
|
db.session.add(app)
|
||||||
|
db.session.commit()
|
||||||
|
yield app
|
||||||
|
db.session.delete(app)
|
||||||
|
db.session.delete(other_tenant)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_denies_account_bearer_without_membership(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
account_token: str,
|
||||||
|
other_workspace_app: App,
|
||||||
|
) -> None:
|
||||||
|
"""Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 403
|
||||||
|
assert res.json.get("message") == "workspace_membership_revoked"
|
||||||
|
|
||||||
|
|
||||||
|
def test_layer0_skipped_when_enterprise_enabled(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
account_token: str,
|
||||||
|
other_workspace_app: App,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
"""On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.
|
||||||
|
|
||||||
|
/info uses validate_bearer + require_workspace_member inline (no
|
||||||
|
AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
|
||||||
|
gets 200 — gateway is expected to enforce isolation upstream.
|
||||||
|
"""
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
# Override the conftest autouse default for this test only.
|
||||||
|
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)
|
||||||
|
|
||||||
|
res = test_client.get(
|
||||||
|
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||||
|
headers={"Authorization": f"Bearer {account_token}"},
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json.get("message") != "workspace_membership_revoked"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_returns_429_after_60_requests(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
account_token: str,
|
||||||
|
) -> None:
|
||||||
|
"""61st sequential GET to /account on the same bearer → 429 with Retry-After."""
|
||||||
|
headers = {"Authorization": f"Bearer {account_token}"}
|
||||||
|
for i in range(60):
|
||||||
|
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||||
|
assert r.status_code == 200, f"unexpected fail at i={i}"
|
||||||
|
|
||||||
|
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||||
|
assert r.status_code == 429
|
||||||
|
assert r.headers.get("Retry-After"), "Retry-After header missing"
|
||||||
|
assert int(r.headers["Retry-After"]) >= 1
|
||||||
|
body = r.json or {}
|
||||||
|
assert body.get("error") == "rate_limited"
|
||||||
|
assert isinstance(body.get("retry_after_ms"), int)
|
||||||
|
assert body["retry_after_ms"] >= 1000
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_bucket_shared_across_surfaces(
|
||||||
|
test_client: FlaskClient,
|
||||||
|
app_in_workspace: App,
|
||||||
|
account_token: str,
|
||||||
|
) -> None:
|
||||||
|
"""30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
|
||||||
|
headers = {"Authorization": f"Bearer {account_token}"}
|
||||||
|
for _ in range(30):
|
||||||
|
assert test_client.get("/openapi/v1/account", headers=headers).status_code == 200
|
||||||
|
for _ in range(30):
|
||||||
|
assert test_client.get(f"/openapi/v1/apps/{app_in_workspace.id}/info", headers=headers).status_code == 200
|
||||||
|
|
||||||
|
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||||
|
assert r.status_code == 429
|
||||||
@ -1,4 +1,4 @@
|
|||||||
"""Tests for core.plugin.plugin_service.PluginService.
|
"""Tests for services.plugin.plugin_service.PluginService.
|
||||||
|
|
||||||
Covers: version caching with Redis, install permission/scope gates,
|
Covers: version caching with Redis, install permission/scope gates,
|
||||||
icon URL construction, asset retrieval with MIME guessing, plugin
|
icon URL construction, asset retrieval with MIME guessing, plugin
|
||||||
@ -17,11 +17,11 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from core.plugin.entities.plugin import PluginInstallationSource
|
from core.plugin.entities.plugin import PluginInstallationSource
|
||||||
from core.plugin.entities.plugin_daemon import PluginVerification
|
from core.plugin.entities.plugin_daemon import PluginVerification
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from models import ProviderType
|
from models import ProviderType
|
||||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||||
from services.errors.plugin import PluginInstallationForbiddenError
|
from services.errors.plugin import PluginInstallationForbiddenError
|
||||||
from services.feature_service import PluginInstallationScope
|
from services.feature_service import PluginInstallationScope
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
|
|
||||||
def _make_features(
|
def _make_features(
|
||||||
@ -35,8 +35,8 @@ def _make_features(
|
|||||||
|
|
||||||
|
|
||||||
class TestFetchLatestPluginVersion:
|
class TestFetchLatestPluginVersion:
|
||||||
@patch("core.plugin.plugin_service.marketplace")
|
@patch("services.plugin.plugin_service.marketplace")
|
||||||
@patch("core.plugin.plugin_service.redis_client")
|
@patch("services.plugin.plugin_service.redis_client")
|
||||||
def test_returns_cached_version(self, mock_redis, mock_marketplace):
|
def test_returns_cached_version(self, mock_redis, mock_marketplace):
|
||||||
cached_json = PluginService.LatestPluginCache(
|
cached_json = PluginService.LatestPluginCache(
|
||||||
plugin_id="p1",
|
plugin_id="p1",
|
||||||
@ -53,8 +53,8 @@ class TestFetchLatestPluginVersion:
|
|||||||
assert result["p1"].version == "1.0.0"
|
assert result["p1"].version == "1.0.0"
|
||||||
mock_marketplace.batch_fetch_plugin_manifests.assert_not_called()
|
mock_marketplace.batch_fetch_plugin_manifests.assert_not_called()
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.marketplace")
|
@patch("services.plugin.plugin_service.marketplace")
|
||||||
@patch("core.plugin.plugin_service.redis_client")
|
@patch("services.plugin.plugin_service.redis_client")
|
||||||
def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace):
|
def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
manifest = MagicMock()
|
manifest = MagicMock()
|
||||||
@ -71,8 +71,8 @@ class TestFetchLatestPluginVersion:
|
|||||||
assert result["p1"].version == "2.0.0"
|
assert result["p1"].version == "2.0.0"
|
||||||
mock_redis.setex.assert_called_once()
|
mock_redis.setex.assert_called_once()
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.marketplace")
|
@patch("services.plugin.plugin_service.marketplace")
|
||||||
@patch("core.plugin.plugin_service.redis_client")
|
@patch("services.plugin.plugin_service.redis_client")
|
||||||
def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace):
|
def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
mock_marketplace.batch_fetch_plugin_manifests.return_value = []
|
mock_marketplace.batch_fetch_plugin_manifests.return_value = []
|
||||||
@ -81,8 +81,8 @@ class TestFetchLatestPluginVersion:
|
|||||||
|
|
||||||
assert result["unknown"] is None
|
assert result["unknown"] is None
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.marketplace")
|
@patch("services.plugin.plugin_service.marketplace")
|
||||||
@patch("core.plugin.plugin_service.redis_client")
|
@patch("services.plugin.plugin_service.redis_client")
|
||||||
def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace):
|
def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace):
|
||||||
mock_redis.get.return_value = None
|
mock_redis.get.return_value = None
|
||||||
mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error")
|
mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error")
|
||||||
@ -93,14 +93,14 @@ class TestFetchLatestPluginVersion:
|
|||||||
|
|
||||||
|
|
||||||
class TestCheckMarketplaceOnlyPermission:
|
class TestCheckMarketplaceOnlyPermission:
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_raises_when_restricted(self, mock_fs):
|
def test_raises_when_restricted(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True)
|
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True)
|
||||||
|
|
||||||
with pytest.raises(PluginInstallationForbiddenError):
|
with pytest.raises(PluginInstallationForbiddenError):
|
||||||
PluginService._check_marketplace_only_permission()
|
PluginService._check_marketplace_only_permission()
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_passes_when_not_restricted(self, mock_fs):
|
def test_passes_when_not_restricted(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False)
|
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False)
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ class TestCheckMarketplaceOnlyPermission:
|
|||||||
|
|
||||||
|
|
||||||
class TestCheckPluginInstallationScope:
|
class TestCheckPluginInstallationScope:
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_official_only_allows_langgenius(self, mock_fs):
|
def test_official_only_allows_langgenius(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
||||||
verification = MagicMock()
|
verification = MagicMock()
|
||||||
@ -116,14 +116,14 @@ class TestCheckPluginInstallationScope:
|
|||||||
|
|
||||||
PluginService._check_plugin_installation_scope(verification) # should not raise
|
PluginService._check_plugin_installation_scope(verification) # should not raise
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_official_only_rejects_third_party(self, mock_fs):
|
def test_official_only_rejects_third_party(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
|
||||||
|
|
||||||
with pytest.raises(PluginInstallationForbiddenError):
|
with pytest.raises(PluginInstallationForbiddenError):
|
||||||
PluginService._check_plugin_installation_scope(None)
|
PluginService._check_plugin_installation_scope(None)
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_official_and_partners_allows_partner(self, mock_fs):
|
def test_official_and_partners_allows_partner(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(
|
mock_fs.get_system_features.return_value = _make_features(
|
||||||
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
||||||
@ -133,7 +133,7 @@ class TestCheckPluginInstallationScope:
|
|||||||
|
|
||||||
PluginService._check_plugin_installation_scope(verification) # should not raise
|
PluginService._check_plugin_installation_scope(verification) # should not raise
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_official_and_partners_rejects_none(self, mock_fs):
|
def test_official_and_partners_rejects_none(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(
|
mock_fs.get_system_features.return_value = _make_features(
|
||||||
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
|
||||||
@ -142,7 +142,7 @@ class TestCheckPluginInstallationScope:
|
|||||||
with pytest.raises(PluginInstallationForbiddenError):
|
with pytest.raises(PluginInstallationForbiddenError):
|
||||||
PluginService._check_plugin_installation_scope(None)
|
PluginService._check_plugin_installation_scope(None)
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_none_scope_always_raises(self, mock_fs):
|
def test_none_scope_always_raises(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE)
|
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE)
|
||||||
verification = MagicMock()
|
verification = MagicMock()
|
||||||
@ -151,7 +151,7 @@ class TestCheckPluginInstallationScope:
|
|||||||
with pytest.raises(PluginInstallationForbiddenError):
|
with pytest.raises(PluginInstallationForbiddenError):
|
||||||
PluginService._check_plugin_installation_scope(verification)
|
PluginService._check_plugin_installation_scope(verification)
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
def test_all_scope_passes_any(self, mock_fs):
|
def test_all_scope_passes_any(self, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL)
|
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL)
|
||||||
|
|
||||||
@ -159,7 +159,7 @@ class TestCheckPluginInstallationScope:
|
|||||||
|
|
||||||
|
|
||||||
class TestGetPluginIconUrl:
|
class TestGetPluginIconUrl:
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_constructs_url_with_params(self, mock_config):
|
def test_constructs_url_with_params(self, mock_config):
|
||||||
mock_config.CONSOLE_API_URL = "https://console.example.com"
|
mock_config.CONSOLE_API_URL = "https://console.example.com"
|
||||||
|
|
||||||
@ -171,7 +171,7 @@ class TestGetPluginIconUrl:
|
|||||||
|
|
||||||
|
|
||||||
class TestGetAsset:
|
class TestGetAsset:
|
||||||
@patch("core.plugin.plugin_service.PluginAssetManager")
|
@patch("services.plugin.plugin_service.PluginAssetManager")
|
||||||
def test_returns_bytes_and_guessed_mime(self, mock_asset_cls):
|
def test_returns_bytes_and_guessed_mime(self, mock_asset_cls):
|
||||||
mock_asset_cls.return_value.fetch_asset.return_value = b"<svg/>"
|
mock_asset_cls.return_value.fetch_asset.return_value = b"<svg/>"
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ class TestGetAsset:
|
|||||||
assert data == b"<svg/>"
|
assert data == b"<svg/>"
|
||||||
assert "svg" in mime
|
assert "svg" in mime
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.PluginAssetManager")
|
@patch("services.plugin.plugin_service.PluginAssetManager")
|
||||||
def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls):
|
def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls):
|
||||||
mock_asset_cls.return_value.fetch_asset.return_value = b"\x00"
|
mock_asset_cls.return_value.fetch_asset.return_value = b"\x00"
|
||||||
|
|
||||||
@ -190,13 +190,13 @@ class TestGetAsset:
|
|||||||
|
|
||||||
|
|
||||||
class TestIsPluginVerified:
|
class TestIsPluginVerified:
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_returns_true_when_verified(self, mock_installer_cls):
|
def test_returns_true_when_verified(self, mock_installer_cls):
|
||||||
mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True
|
mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True
|
||||||
|
|
||||||
assert PluginService.is_plugin_verified("t1", "uid-1") is True
|
assert PluginService.is_plugin_verified("t1", "uid-1") is True
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_returns_false_on_exception(self, mock_installer_cls):
|
def test_returns_false_on_exception(self, mock_installer_cls):
|
||||||
mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found")
|
mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found")
|
||||||
|
|
||||||
@ -204,24 +204,24 @@ class TestIsPluginVerified:
|
|||||||
|
|
||||||
|
|
||||||
class TestUpgradePluginWithMarketplace:
|
class TestUpgradePluginWithMarketplace:
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_raises_when_marketplace_disabled(self, mock_config):
|
def test_raises_when_marketplace_disabled(self, mock_config):
|
||||||
mock_config.MARKETPLACE_ENABLED = False
|
mock_config.MARKETPLACE_ENABLED = False
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
||||||
PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid")
|
PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid")
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_raises_when_same_identifier(self, mock_config):
|
def test_raises_when_same_identifier(self, mock_config):
|
||||||
mock_config.MARKETPLACE_ENABLED = True
|
mock_config.MARKETPLACE_ENABLED = True
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="same plugin"):
|
with pytest.raises(ValueError, match="same plugin"):
|
||||||
PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid")
|
PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid")
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.marketplace")
|
@patch("services.plugin.plugin_service.marketplace")
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace):
|
def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace):
|
||||||
mock_config.MARKETPLACE_ENABLED = True
|
mock_config.MARKETPLACE_ENABLED = True
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
@ -234,10 +234,10 @@ class TestUpgradePluginWithMarketplace:
|
|||||||
mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid")
|
mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid")
|
||||||
installer.upgrade_plugin.assert_called_once()
|
installer.upgrade_plugin.assert_called_once()
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.download_plugin_pkg")
|
@patch("services.plugin.plugin_service.download_plugin_pkg")
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
||||||
mock_config.MARKETPLACE_ENABLED = True
|
mock_config.MARKETPLACE_ENABLED = True
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
@ -256,8 +256,8 @@ class TestUpgradePluginWithMarketplace:
|
|||||||
|
|
||||||
|
|
||||||
class TestUpgradePluginWithGithub:
|
class TestUpgradePluginWithGithub:
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
|
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
installer = mock_installer_cls.return_value
|
installer = mock_installer_cls.return_value
|
||||||
@ -271,8 +271,8 @@ class TestUpgradePluginWithGithub:
|
|||||||
|
|
||||||
|
|
||||||
class TestUploadPkg:
|
class TestUploadPkg:
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
|
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
upload_resp = MagicMock()
|
upload_resp = MagicMock()
|
||||||
@ -285,17 +285,17 @@ class TestUploadPkg:
|
|||||||
|
|
||||||
|
|
||||||
class TestInstallFromMarketplacePkg:
|
class TestInstallFromMarketplacePkg:
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_raises_when_marketplace_disabled(self, mock_config):
|
def test_raises_when_marketplace_disabled(self, mock_config):
|
||||||
mock_config.MARKETPLACE_ENABLED = False
|
mock_config.MARKETPLACE_ENABLED = False
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
with pytest.raises(ValueError, match="marketplace is not enabled"):
|
||||||
PluginService.install_from_marketplace_pkg("t1", ["uid-1"])
|
PluginService.install_from_marketplace_pkg("t1", ["uid-1"])
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.download_plugin_pkg")
|
@patch("services.plugin.plugin_service.download_plugin_pkg")
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download):
|
||||||
mock_config.MARKETPLACE_ENABLED = True
|
mock_config.MARKETPLACE_ENABLED = True
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
@ -315,9 +315,9 @@ class TestInstallFromMarketplacePkg:
|
|||||||
call_args = installer.install_from_identifiers.call_args[0]
|
call_args = installer.install_from_identifiers.call_args[0]
|
||||||
assert call_args[1] == ["resolved-uid"]
|
assert call_args[1] == ["resolved-uid"]
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.FeatureService")
|
@patch("services.plugin.plugin_service.FeatureService")
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
@patch("core.plugin.plugin_service.dify_config")
|
@patch("services.plugin.plugin_service.dify_config")
|
||||||
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
|
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
|
||||||
mock_config.MARKETPLACE_ENABLED = True
|
mock_config.MARKETPLACE_ENABLED = True
|
||||||
mock_fs.get_system_features.return_value = _make_features()
|
mock_fs.get_system_features.return_value = _make_features()
|
||||||
@ -336,7 +336,7 @@ class TestInstallFromMarketplacePkg:
|
|||||||
|
|
||||||
|
|
||||||
class TestUninstall:
|
class TestUninstall:
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls):
|
def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls):
|
||||||
installer = mock_installer_cls.return_value
|
installer = mock_installer_cls.return_value
|
||||||
installer.list_plugins.return_value = []
|
installer.list_plugins.return_value = []
|
||||||
@ -347,7 +347,7 @@ class TestUninstall:
|
|||||||
assert result is True
|
assert result is True
|
||||||
installer.uninstall.assert_called_once_with("t1", "install-1")
|
installer.uninstall.assert_called_once_with("t1", "install-1")
|
||||||
|
|
||||||
@patch("core.plugin.plugin_service.PluginInstaller")
|
@patch("services.plugin.plugin_service.PluginInstaller")
|
||||||
def test_cleans_credentials_when_plugin_found(
|
def test_cleans_credentials_when_plugin_found(
|
||||||
self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session
|
self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session
|
||||||
):
|
):
|
||||||
@ -389,7 +389,7 @@ class TestUninstall:
|
|||||||
installer.list_plugins.return_value = [plugin]
|
installer.list_plugins.return_value = [plugin]
|
||||||
installer.uninstall.return_value = True
|
installer.uninstall.return_value = True
|
||||||
|
|
||||||
with patch("core.plugin.plugin_service.dify_config") as mock_config:
|
with patch("services.plugin.plugin_service.dify_config") as mock_config:
|
||||||
mock_config.ENTERPRISE_ENABLED = False
|
mock_config.ENTERPRISE_ENABLED = False
|
||||||
result = PluginService.uninstall(tenant_id, "install-1")
|
result = PluginService.uninstall(tenant_id, "install-1")
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import pytest
|
|||||||
from faker import Faker
|
from faker import Faker
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.plugin.plugin_service import PluginService
|
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
@ -21,6 +20,7 @@ from core.tools.entities.tool_entities import (
|
|||||||
ToolProviderType,
|
ToolProviderType,
|
||||||
)
|
)
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class TestToolTransformService:
|
|||||||
def mock_external_service_dependencies(self):
|
def mock_external_service_dependencies(self):
|
||||||
"""Mock setup for external service dependencies."""
|
"""Mock setup for external service dependencies."""
|
||||||
with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
|
with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
|
||||||
with patch("core.plugin.plugin_service.dify_config", new=mock_dify_config):
|
with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config):
|
||||||
# Setup default mock returns
|
# Setup default mock returns
|
||||||
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
|
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user