mirror of
https://github.com/langgenius/dify.git
synced 2026-05-15 14:37:59 +08:00
Compare commits
9 Commits
4-27-app-d
...
feat/cli
| Author | SHA1 | Date | |
|---|---|---|---|
| c2b91d849d | |||
| e0f4e98a2f | |||
| 9d554495cf | |||
| c2868075fa | |||
| 1a83dfaf1f | |||
| 83d14e0540 | |||
| 1f7da9c191 | |||
| b21d0ae32d | |||
| 6779366dca |
@ -1,6 +1,6 @@
|
||||
---
|
||||
name: how-to-write-component
|
||||
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around abstraction choices, props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
|
||||
description: React/TypeScript component style guide. Use when writing, refactoring, or reviewing React components, especially around props typing, state boundaries, shared local state with Jotai atoms, API types, query/mutation contracts, navigation, memoization, wrappers, and empty-state handling.
|
||||
---
|
||||
|
||||
# How To Write A Component
|
||||
@ -12,7 +12,6 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Search before adding UI, hooks, helpers, or styling patterns. Reuse existing base components, feature components, hooks, utilities, and design styles when they fit.
|
||||
- Group code by feature workflow, route, or ownership area: components, hooks, local types, query helpers, atoms, constants, and small utilities should live near the code that changes with them.
|
||||
- Promote code to shared only when multiple verticals need the same stable primitive. Otherwise keep it local and compose shared primitives inside the owning feature.
|
||||
- Prefer local code and purpose-named helpers over catch-all utility modules; inline cheap derived values when that is clearer.
|
||||
- Use Tailwind CSS v4.1+ rules via the `tailwind-css-rules` skill. Prefer v4 utilities, `gap`, `text-size/line-height`, `min-h-dvh`, and avoid deprecated utilities and `@apply`.
|
||||
|
||||
## Ownership
|
||||
@ -20,8 +19,6 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Put local state, queries, mutations, handlers, and derived UI data in the lowest component that uses them. Extract a purpose-built owner component only when the logic has no natural home.
|
||||
- Repeated TanStack query calls in sibling components are acceptable when each component independently consumes the data. Do not hoist a query only because it is duplicated; TanStack Query handles deduplication and cache sharing.
|
||||
- Hoist state, queries, or callbacks to a parent only when the parent consumes the data, coordinates shared loading/error/empty UI, needs one consistent snapshot, or owns a workflow spanning children.
|
||||
- Pass stable domain identity across boundaries; avoid forwarding derived presentation state when the receiver can derive it from its own data source. A component that owns a visual surface should also own the data access, loading, empty, and error states for content rendered inside it unless a parent truly coordinates that state.
|
||||
- Loading states for visual surfaces should use skeleton placeholders scoped to the content that is actually loading, with shape, density, and dimensions close to the final UI. Avoid generic loading text or centered spinners for page sections, cards, lists, tables, forms, and drawers; reserve spinners for small inline busy indicators such as an in-progress status icon.
|
||||
- Avoid prop drilling. One pass-through layer is acceptable; repeated forwarding means ownership should move down or into feature-scoped Jotai UI state. Keep server/cache state in query and API data flow.
|
||||
- Keep callbacks in a parent only for workflow coordination such as form submission, shared selection, batch behavior, or navigation. Otherwise let the child or row own its action.
|
||||
- Prefer uncontrolled DOM state and CSS variables before adding controlled props.
|
||||
@ -32,9 +29,9 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
- Prefer `function` for top-level components and module helpers. Use arrow functions for local callbacks, handlers, and lambda-style APIs.
|
||||
- Prefer named exports. Use default exports only where the framework requires them, such as Next.js route files.
|
||||
- Type simple one-off props inline. Use a named `Props` type only when reused, exported, complex, or clearer.
|
||||
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers and one-off UI extensions beside the component that needs them.
|
||||
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially persistent IDs and route params. Normalize framework or route params at the boundary.
|
||||
- Keep fallback and invariant checks at the lowest component that already handles that state; avoid defensive fallbacks that mask impossible states.
|
||||
- Use API-generated or API-returned types at component boundaries. Keep small UI conversion helpers beside the component that needs them.
|
||||
- Name values by their domain role and backend API contract, and keep that name stable across the call chain, especially IDs like `appInstanceId`. Normalize framework or route params at the boundary.
|
||||
- Keep fallback and invariant checks at the lowest component that already handles that state; callers should pass raw values through instead of duplicating checks.
|
||||
|
||||
## Queries And Mutations
|
||||
|
||||
@ -51,13 +48,12 @@ Use this as the decision guide for React/TypeScript component structure. Existin
|
||||
## Component Boundaries
|
||||
|
||||
- Use the first level below a page or tab to organize independent page sections when it adds real structure. This layer is layout/semantic first, not automatically the data owner.
|
||||
- Treat component names, semantic roles, and user- or design-marked visual regions as boundary constraints. Do not expand a child component's responsibility just because its data is useful nearby; keep adjacent UI as a sibling owner or introduce a correctly named broader owner.
|
||||
- Split deeper components by the data and state each layer actually needs. Each component should access only necessary data, and ownership should stay at the lowest consumer.
|
||||
- Keep cohesive forms, menu bodies, and one-off helpers local unless they need their own state, reuse, or semantic boundary.
|
||||
- Separate hidden secondary surfaces from the trigger's main flow. For dialogs, dropdowns, popovers, and similar branches, extract a small local component that owns the trigger, open state, and hidden content when it would obscure the parent flow.
|
||||
- Preserve composability by separating behavior ownership from layout ownership. A dropdown action may own its trigger, open state, and menu content; the caller owns placement such as slots, offsets, and alignment.
|
||||
- Avoid unnecessary DOM hierarchy. Do not add wrapper elements unless they provide layout, semantics, accessibility, state ownership, or integration with a library API; prefer fragments or styling an existing element when possible.
|
||||
- Avoid shallow wrappers, layout-only render-prop wrappers, and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
|
||||
- Avoid shallow wrappers and prop renaming unless the wrapper adds validation, orchestration, error handling, state ownership, or a real semantic boundary.
|
||||
|
||||
## You Might Not Need An Effect
|
||||
|
||||
|
||||
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@ -0,0 +1,15 @@
|
||||
**/node_modules
|
||||
**/.pnpm-store
|
||||
**/dist
|
||||
**/.next
|
||||
**/.turbo
|
||||
**/.cache
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/.mypy_cache
|
||||
**/.ruff_cache
|
||||
.git
|
||||
.github
|
||||
*.md
|
||||
!web/README.md
|
||||
!api/README.md
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,6 +18,10 @@
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# CLI
|
||||
/cli/ @langgenius/maintainers
|
||||
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
|
||||
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@ -9,7 +9,6 @@ on:
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "feat/hitl-backend"
|
||||
- "4-27-app-deploy"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
63
.github/workflows/cli-docker-build.yml
vendored
Normal file
63
.github/workflows/cli-docker-build.yml
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
name: CLI Docker Build (dev)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- "cli/**"
|
||||
- "packages/tsconfig/**"
|
||||
- "pnpm-lock.yaml"
|
||||
- "pnpm-workspace.yaml"
|
||||
merge_group:
|
||||
branches:
|
||||
- "main"
|
||||
types: [checks_requested]
|
||||
|
||||
concurrency:
|
||||
group: cli-docker-build-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build CLI dev image
|
||||
if: github.event_name == 'merge_group' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build CLI Dockerfile.dev
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
context: "{{defaultContext}}"
|
||||
file: "cli/Dockerfile.dev"
|
||||
platforms: linux/amd64,linux/arm64
|
||||
|
||||
build-fork:
|
||||
name: Build CLI dev image (fork)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build CLI Dockerfile.dev
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: "."
|
||||
file: "cli/Dockerfile.dev"
|
||||
platforms: linux/amd64
|
||||
131
.github/workflows/cli-release.yml
vendored
Normal file
131
.github/workflows/cli-release.yml
vendored
Normal file
@ -0,0 +1,131 @@
|
||||
name: CLI Release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_release_tag:
|
||||
description: "dify release tag to attach cli artifacts to (e.g. 1.14.0). Bare semver — dify tags are NOT v-prefixed."
|
||||
type: string
|
||||
required: true
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.event.release.tag_name || inputs.dify_release_tag }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: >-
|
||||
github.repository == 'langgenius/dify' &&
|
||||
(github.event_name == 'workflow_dispatch' ||
|
||||
(vars.CLI_AUTO_RELEASE == 'true' && !github.event.release.prerelease))
|
||||
env:
|
||||
DIFY_TAG: ${{ github.event.release.tag_name || inputs.dify_release_tag }}
|
||||
permissions:
|
||||
contents: write
|
||||
id-token: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Node registry auth
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Bump guard (auto-path only)
|
||||
if: github.event_name == 'release'
|
||||
run: scripts/release-bump-guard.sh
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
NEW_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
NEW_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
|
||||
- name: Verify target dify release exists
|
||||
run: gh release view "$DIFY_TAG" --repo langgenius/dify > /dev/null
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build cli
|
||||
run: |
|
||||
DIFYCTL_VERSION="${{ steps.manifest.outputs.version }}" \
|
||||
DIFYCTL_CHANNEL="${{ steps.manifest.outputs.channel }}" \
|
||||
DIFYCTL_MIN_DIFY="${{ steps.manifest.outputs.minDify }}" \
|
||||
DIFYCTL_MAX_DIFY="${{ steps.manifest.outputs.maxDify }}" \
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build
|
||||
|
||||
- name: Pack tarballs
|
||||
run: pnpm pack:tarballs
|
||||
|
||||
- name: Publish to npm (idempotent)
|
||||
run: scripts/release-npm-publish.sh
|
||||
env:
|
||||
CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
NEW_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
run: scripts/release-write-checksums.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@3454372f43399081ed03b604cb2d021dabca52bb # v3.8.2
|
||||
|
||||
- name: Keyless-sign tarballs + checksum file (Sigstore)
|
||||
run: scripts/release-cosign-sign.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
|
||||
- name: Snapshot tarballs + checksum + signatures as workflow artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: difyctl-${{ steps.manifest.outputs.version }}-${{ env.DIFY_TAG }}
|
||||
path: |
|
||||
cli/dist/difyctl-v*.tar.xz
|
||||
cli/dist/difyctl-v*-checksums.txt
|
||||
cli/dist/difyctl-v*.sig
|
||||
cli/dist/difyctl-v*.pem
|
||||
retention-days: 90
|
||||
if-no-files-found: error
|
||||
|
||||
- name: Upload tarballs + checksum + signatures to dify GH release (idempotent)
|
||||
run: scripts/release-upload-tarballs.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
57
.github/workflows/cli-smoke.yml
vendored
Normal file
57
.github/workflows/cli-smoke.yml
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
name: CLI Smoke (live dify)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_version:
|
||||
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||
type: string
|
||||
required: true
|
||||
cli_ref:
|
||||
description: "Git ref to build the cli from (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
smoke:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Bring up dify
|
||||
env:
|
||||
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||
run: |
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||
docker compose up -d api worker web db redis
|
||||
for i in $(seq 1 60); do
|
||||
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||
echo "dify api ready after ${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run smoke against live dify
|
||||
working-directory: ./cli
|
||||
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||
|
||||
- name: Dump dify logs on failure
|
||||
if: failure()
|
||||
run: |
|
||||
cd docker
|
||||
docker compose logs api worker web --tail=200
|
||||
46
.github/workflows/cli-tests.yml
vendored
Normal file
46
.github/workflows/cli-tests.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: CLI Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,6 +42,7 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
@ -63,6 +64,18 @@ jobs:
|
||||
- 'docker/generate_docker_compose'
|
||||
- 'docker/ssrf_proxy/**'
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
cli:
|
||||
- 'cli/**'
|
||||
- 'packages/tsconfig/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- 'eslint.config.mjs'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/cli-tests.yml'
|
||||
- '.github/workflows/cli-docker-build.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
@ -184,6 +197,66 @@ jobs:
|
||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
cli-tests-run:
|
||||
name: Run CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||
uses: ./.github/workflows/cli-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests-skip:
|
||||
name: Skip CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped CLI tests
|
||||
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
if: ${{ always() }}
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
- cli-tests-run
|
||||
- cli-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize CLI Tests status
|
||||
env:
|
||||
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||
run: |
|
||||
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests ran successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests were skipped because no CLI-related files changed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
web-tests-run:
|
||||
name: Run Web Tests
|
||||
needs:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -115,6 +115,12 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||
!/cli/src/env/
|
||||
!/cli/src/commands/env/
|
||||
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||
!/cli/scripts/lib/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
@ -247,6 +253,7 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
*.local.toml
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
|
||||
@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -520,6 +520,44 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
OPENAPI_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||
"programmatic clients. Set to true to activate; disabled by default."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -895,6 +933,17 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1181,6 +1230,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -16,7 +16,6 @@ api = ExternalApi(
|
||||
inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/")
|
||||
|
||||
from . import mail as _mail
|
||||
from . import runtime_credentials as _runtime_credentials
|
||||
from .app import dsl as _app_dsl
|
||||
from .plugin import plugin as _plugin
|
||||
from .workspace import workspace as _workspace
|
||||
@ -27,7 +26,6 @@ __all__ = [
|
||||
"_app_dsl",
|
||||
"_mail",
|
||||
"_plugin",
|
||||
"_runtime_credentials",
|
||||
"_workspace",
|
||||
"api",
|
||||
"bp",
|
||||
|
||||
@ -1,129 +0,0 @@
|
||||
"""Inner API endpoints for runtime credential resolution.
|
||||
|
||||
Called by Enterprise while resolving AppRunner runtime artifacts. The endpoint
|
||||
returns decrypted model credentials for in-memory runtime use only.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
from controllers.inner_api import inner_api_ns
|
||||
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||
from core.helper import encrypter
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from extensions.ext_database import db
|
||||
from models.provider import ProviderCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InnerRuntimeModelCredentialResolveItem(BaseModel):
|
||||
credential_id: str = Field(description="Provider credential id")
|
||||
provider: str = Field(description="Runtime provider identifier, for example langgenius/openai/openai")
|
||||
vendor: str | None = Field(default=None, description="Model vendor, for example openai")
|
||||
plugin_unique_identifier: str | None = Field(default=None, description="Runtime plugin identifier")
|
||||
|
||||
|
||||
class InnerRuntimeModelCredentialsResolvePayload(BaseModel):
|
||||
tenant_id: str = Field(description="Workspace id")
|
||||
credentials: list[InnerRuntimeModelCredentialResolveItem] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_model(inner_api_ns, InnerRuntimeModelCredentialsResolvePayload)
|
||||
|
||||
|
||||
@inner_api_ns.route("/enterprise/runtime/model-credentials:resolve")
|
||||
class EnterpriseRuntimeModelCredentialsResolve(Resource):
|
||||
@setup_required
|
||||
@enterprise_inner_api_only
|
||||
@inner_api_ns.doc(
|
||||
"enterprise_runtime_model_credentials_resolve",
|
||||
responses={
|
||||
200: "Credentials resolved",
|
||||
400: "Invalid request or credential config",
|
||||
404: "Provider or credential not found",
|
||||
},
|
||||
)
|
||||
@inner_api_ns.expect(inner_api_ns.models[InnerRuntimeModelCredentialsResolvePayload.__name__])
|
||||
def post(self):
|
||||
args = InnerRuntimeModelCredentialsResolvePayload.model_validate(inner_api_ns.payload or {})
|
||||
if not args.credentials:
|
||||
return {"model_credentials": []}, 200
|
||||
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=args.tenant_id)
|
||||
provider_configurations = provider_manager.get_configurations(args.tenant_id)
|
||||
|
||||
resolved: list[dict[str, Any]] = []
|
||||
for item in args.credentials:
|
||||
provider_configuration = provider_configurations.get(item.provider)
|
||||
if provider_configuration is None:
|
||||
return {"message": f"provider '{item.provider}' not found"}, 404
|
||||
|
||||
provider_schema = provider_configuration.provider.provider_credential_schema
|
||||
secret_variables = provider_configuration.extract_secret_variables(
|
||||
provider_schema.credential_form_schemas if provider_schema else []
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == item.credential_id,
|
||||
ProviderCredential.tenant_id == args.tenant_id,
|
||||
ProviderCredential.provider_name.in_(provider_configuration._get_provider_names()),
|
||||
)
|
||||
credential = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if credential is None or not credential.encrypted_config:
|
||||
return {"message": f"credential '{item.credential_id}' not found"}, 404
|
||||
|
||||
try:
|
||||
values = json.loads(credential.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
return {"message": f"credential '{item.credential_id}' has invalid config"}, 400
|
||||
if not isinstance(values, dict):
|
||||
return {"message": f"credential '{item.credential_id}' has invalid config"}, 400
|
||||
|
||||
for key in secret_variables:
|
||||
value = values.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
try:
|
||||
values[key] = encrypter.decrypt_token(tenant_id=args.tenant_id, token=value)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"failed to resolve runtime model credential",
|
||||
extra={
|
||||
"credential_id": item.credential_id,
|
||||
"provider": item.provider,
|
||||
"tenant_id": args.tenant_id,
|
||||
"error": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
return {"message": f"credential '{item.credential_id}' decrypt failed"}, 400
|
||||
|
||||
resolved.append(
|
||||
{
|
||||
"credential_id": item.credential_id,
|
||||
"provider": item.provider,
|
||||
"vendor": item.vendor or _vendor_from_provider(item.provider),
|
||||
"plugin_unique_identifier": item.plugin_unique_identifier,
|
||||
"values": values,
|
||||
}
|
||||
)
|
||||
|
||||
return {"model_credentials": resolved}, 200
|
||||
|
||||
|
||||
def _vendor_from_provider(provider: str) -> str:
|
||||
provider = provider.strip("/")
|
||||
if not provider:
|
||||
return ""
|
||||
return provider.rsplit("/", 1)[-1]
|
||||
122
api/controllers/openapi/__init__.py
Normal file
122
api/controllers/openapi/__init__.py
Normal file
@ -0,0 +1,122 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
AppRunRequest,
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkflowRunResponse,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspacePayload,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
WorkflowRunData,
|
||||
WorkflowRunResponse,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
AccountResponse,
|
||||
SessionRow,
|
||||
SessionListResponse,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
# Request models are imported from _models.py and registered above.
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
66
api/controllers/openapi/_audit.py
Normal file
66
api/controllers/openapi/_audit.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||
|
||||
|
||||
def emit_app_run(
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
caller_kind: str,
|
||||
mode: str,
|
||||
surface: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
surface,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
"surface": surface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_wrong_surface(
|
||||
*,
|
||||
subject_type: str | None,
|
||||
attempted_path: str,
|
||||
client_id: str | None,
|
||||
token_id: str | None,
|
||||
) -> None:
|
||||
logger.warning(
|
||||
"audit: %s subject_type=%s attempted_path=%s",
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
subject_type,
|
||||
attempted_path,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
"subject_type": subject_type,
|
||||
"attempted_path": attempted_path,
|
||||
"client_id": client_id,
|
||||
"token_id": token_id,
|
||||
},
|
||||
)
|
||||
143
api/controllers/openapi/_input_schema.py
Normal file
143
api/controllers/openapi/_input_schema.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
320
api/controllers/openapi/_models.py
Normal file
320
api/controllers/openapi/_models.py
Normal file
@ -0,0 +1,320 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class PermittedExternalAppsListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
class AccountPayload(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspacePayload(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
subject_type: str
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account: AccountPayload | None = None
|
||||
workspaces: list[WorkspacePayload] = []
|
||||
default_workspace_id: str | None = None
|
||||
|
||||
|
||||
class SessionRow(BaseModel):
|
||||
id: str
|
||||
prefix: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
created_at: str | None = None
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WorkspaceSummaryResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
workspaces: list[WorkspaceSummaryResponse]
|
||||
|
||||
|
||||
class WorkspaceDetailResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class DeviceCodeResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceLookupResponse(BaseModel):
|
||||
valid: bool
|
||||
expires_in_remaining: int = 0
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class DeviceMutateResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
workspace_id: UUIDStrOrEmpty | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class PermittedExternalAppsListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
249
api/controllers/openapi/account.py
Normal file
249
api/controllers/openapi/account.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""User-scoped account endpoints. /account is the bearer-authed
|
||||
identity read; /account/sessions and /account/sessions/<id> manage
|
||||
the user's active OAuth tokens.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import and_, select, update
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = (
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
|
||||
)
|
||||
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
_revoke_token_by_id(str(ctx.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = db.session.execute(
|
||||
select(
|
||||
OAuthAccessToken.id,
|
||||
OAuthAccessToken.prefix,
|
||||
OAuthAccessToken.client_id,
|
||||
OAuthAccessToken.device_label,
|
||||
OAuthAccessToken.created_at,
|
||||
OAuthAccessToken.last_used_at,
|
||||
OAuthAccessToken.expires_at,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
*_subject_match(ctx),
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
OAuthAccessToken.token_hash.is_not(None),
|
||||
OAuthAccessToken.expires_at > now,
|
||||
)
|
||||
)
|
||||
.order_by(OAuthAccessToken.created_at.desc())
|
||||
).all()
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
SessionRow(
|
||||
id=str(r.id),
|
||||
prefix=r.prefix,
|
||||
client_id=r.client_id,
|
||||
device_label=r.device_label,
|
||||
created_at=_iso(r.created_at),
|
||||
last_used_at=_iso(r.last_used_at),
|
||||
expires_at=_iso(r.expires_at),
|
||||
)
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# Subject-match guard. 404 (not 403) on cross-subject so the
|
||||
# endpoint doesn't leak token IDs that belong to other subjects.
|
||||
owns = db.session.execute(
|
||||
select(OAuthAccessToken.id).where(
|
||||
and_(
|
||||
OAuthAccessToken.id == session_id,
|
||||
*_subject_match(ctx),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if owns is None:
|
||||
raise NotFound("session not found")
|
||||
|
||||
_revoke_token_by_id(session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _subject_match(ctx: AuthContext) -> tuple:
|
||||
"""Where-clauses that scope a query to the bearer's subject. Works
|
||||
for both account (account_id) and external_sso (email + issuer).
|
||||
"""
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return (OAuthAccessToken.account_id == str(ctx.account_id),)
|
||||
return (
|
||||
OAuthAccessToken.subject_email == ctx.subject_email,
|
||||
OAuthAccessToken.subject_issuer == ctx.subject_issuer,
|
||||
OAuthAccessToken.account_id.is_(None),
|
||||
)
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _revoke_token_by_id(token_id: str) -> None:
|
||||
# Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
|
||||
# makes double-revoke idempotent.
|
||||
row = (
|
||||
db.session.query(OAuthAccessToken.token_hash)
|
||||
.filter(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
pre_revoke_hash = row[0] if row else None
|
||||
|
||||
stmt = (
|
||||
update(OAuthAccessToken)
|
||||
.where(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||
)
|
||||
db.session.execute(stmt)
|
||||
db.session.commit()
|
||||
|
||||
if pre_revoke_hash:
|
||||
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _load_memberships(account_id):
|
||||
return (
|
||||
db.session.query(TenantAccountJoin, Tenant)
|
||||
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.filter(TenantAccountJoin.account_id == account_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> WorkspacePayload:
|
||||
join, tenant = row
|
||||
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||
|
||||
|
||||
def _account_payload(account) -> AccountPayload:
|
||||
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||
178
api/controllers/openapi/app_run.py
Normal file
178
api/controllers/openapi/app_run.py
Normal file
@ -0,0 +1,178 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import (
|
||||
AppRunRequest,
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
WorkflowRunResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _unpack_blocking(response: Any) -> Mapping[str, Any]:
|
||||
if isinstance(response, tuple):
|
||||
response = response[0]
|
||||
if not isinstance(response, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return response
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, ChatMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, CompletionMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
response = _generate(app, caller, args, streaming)
|
||||
if streaming:
|
||||
return response, None
|
||||
return None, WorkflowRunResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
try:
|
||||
stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
return blocking_body, 200
|
||||
280
api/controllers/openapi/apps.py
Normal file
280
api/controllers/openapi/apps.py
Normal file
@ -0,0 +1,280 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → sets `g.auth_ctx` before `require_scope` reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = g.auth_ctx
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
|
||||
if not app or app.status != "normal" or not is_openapi_visible(app):
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = list(
|
||||
db.session.execute(
|
||||
apply_openapi_gate(
|
||||
sa.select(App).where(
|
||||
App.name == app_id,
|
||||
App.tenant_id == workspace_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = g.auth_ctx
|
||||
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
if parsed_uuid is not None:
|
||||
app: App = db.session.get(App, str(parsed_uuid))
|
||||
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
|
||||
return empty
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all",
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name: str | None = None
|
||||
if pagination.items:
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
env = AppListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=int(pagination.total),
|
||||
has_more=query.page * query.limit < int(pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
107
api/controllers/openapi/apps_permitted_external.py
Normal file
107
api/controllers/openapi/apps_permitted_external.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||
|
||||
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||
(public / sso_verified). License-gated: CE deploys never enable the
|
||||
EE blueprint chain so this module is unreachable there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
from services.openapi.visibility import apply_openapi_gate
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a
|
||||
for a in db.session.execute(apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids))))
|
||||
.scalars()
|
||||
.all()
|
||||
}
|
||||
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
|
||||
tenants_by_id = {
|
||||
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
|
||||
}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=page_result.total,
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
3
api/controllers/openapi/auth/__init__.py
Normal file
3
api/controllers/openapi/auth/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
46
api/controllers/openapi/auth/composition.py
Normal file
46
api/controllers/openapi/auth/composition.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
46
api/controllers/openapi/auth/context.py
Normal file
46
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
request: Request
|
||||
required_scope: Scope
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
41
api/controllers/openapi/auth/pipeline.py
Normal file
41
api/controllers/openapi/auth/pipeline.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
ctx = Context(request=request, required_scope=scope)
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
172
api/controllers/openapi/auth/steps.py
Normal file
172
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,172 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also assigns
|
||||
``g.auth_ctx`` (the same way ``validate_bearer`` does) so the surface gate
|
||||
+ any handler reading the request-scoped context has a single source of
|
||||
truth across both auth-attach paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from flask import g
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
_extract_bearer, # type: ignore[attr-defined]
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also attaches the resolved `AuthContext` to ``g.auth_ctx`` — same shape
|
||||
the decorator-level ``validate_bearer`` writes — so the surface gate
|
||||
+ downstream readers don't see two different identity sources."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
token = _extract_bearer(ctx.request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
|
||||
# Single source of truth for the request-scoped identity. Surface
|
||||
# gate + handlers read `g.auth_ctx` regardless of whether the route
|
||||
# ran the decorator path (`validate_bearer`) or the pipeline path.
|
||||
g.auth_ctx = authn
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if `g.auth_ctx.subject_type` is not in `accepted`.
|
||||
|
||||
Delegates to `surface_gate.check_surface` so the inline decorator and
|
||||
the pipeline step emit identical audit events. Relies on `BearerCheck`
|
||||
(above) having set `g.auth_ctx`.
|
||||
"""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling).
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = (ctx.request.view_args or {}).get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
184
api/controllers/openapi/auth/strategies.py
Normal file
184
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,184 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = db.session.execute(
|
||||
select(Account).where(Account.email == ctx.subject_email),
|
||||
).scalar_one_or_none()
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user)
|
||||
user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
89
api/controllers/openapi/auth/surface_gate.py
Normal file
89
api/controllers/openapi/auth/surface_gate.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that ``g.auth_ctx.subject_type`` is in ``accepted``.
|
||||
|
||||
Raises ``Forbidden`` with ``wrong_surface`` + canonical-path hint on
|
||||
miss; emits ``openapi.wrong_surface_denied`` audit. If ``g.auth_ctx``
|
||||
is missing the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without g.auth_ctx; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
try:
|
||||
return SubjectType(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
9
api/controllers/openapi/index.py
Normal file
9
api/controllers/openapi/index.py
Normal file
@ -0,0 +1,9 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
404
api/controllers/openapi/oauth_device.py
Normal file
404
api/controllers/openapi/oauth_device.py
Normal file
@ -0,0 +1,404 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Validation helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
"""Pre-render the poll-response body so the unauthenticated poll
|
||||
handler doesn't re-query accounts/tenants for authz data.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
rows = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.account_id == account.id)
|
||||
.all()
|
||||
)
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
return {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
369
api/controllers/openapi/oauth_device_sso.py
Normal file
369
api/controllers/openapi/oauth_device_sso.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from controllers.openapi import bp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = (claims.get("user_code") or "").strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims["email"]):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = request.host_url.rstrip("/")
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims["email"],
|
||||
subject_issuer=claims["issuer"],
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or csrf_header != claims.csrf_token:
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RejectedClaims:
|
||||
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||
|
||||
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||
event without reaching for the full dataclass.
|
||||
"""
|
||||
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _email_belongs_to_dify_account(email: str) -> bool:
|
||||
"""External SSO subjects whose email matches an active Dify Account must
|
||||
authenticate via the internal Dify login path (which mints dfoa_), not via
|
||||
the external SSO device flow. Returning True here blocks dfoe_ minting.
|
||||
|
||||
Pending/uninitialized/banned/closed accounts do not block: pending and
|
||||
uninitialized users may complete invitation via SSO; banned and closed
|
||||
accounts are handled by separate enforcement paths.
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
normalized = email.strip().lower()
|
||||
if not normalized:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(Account.id).where(
|
||||
func.lower(Account.email) == normalized,
|
||||
Account.status == AccountStatus.ACTIVE,
|
||||
),
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
reason,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_rejected",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"reason": reason,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
90
api/controllers/openapi/workspaces.py
Normal file
90
api/controllers/openapi/workspaces.py
Normal file
@ -0,0 +1,90 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
rows = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == str(ctx.account_id))
|
||||
.order_by(Tenant.created_at.asc())
|
||||
).all()
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
row = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
Tenant.id == workspace_id,
|
||||
TenantAccountJoin.account_id == str(ctx.account_id),
|
||||
)
|
||||
).first()
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
)
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||
return WorkspaceDetailResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
)
|
||||
@ -16,7 +16,7 @@ from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, EndUser, Site
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(
|
||||
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
raise Unauthorized("Please re-login to access the web app.")
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
|
||||
!= WebAppAccessMode.PUBLIC
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
@ -730,6 +730,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.OPENAPI:
|
||||
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
|
||||
@ -24,6 +24,7 @@ class UserFrom(StrEnum):
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
OPENAPI = "openapi"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
|
||||
@ -45,6 +45,7 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
|
||||
)
|
||||
|
||||
|
||||
@ -161,6 +162,8 @@ def create_spec_app() -> Flask:
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.console import console_ns
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web import bp as web_bp
|
||||
@ -169,8 +172,9 @@ def create_spec_app() -> Flask:
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
for namespace in (console_ns, web_ns, service_api_ns):
|
||||
for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
|
||||
for api in namespace.apis:
|
||||
_materialize_inline_model_definitions(api)
|
||||
|
||||
@ -201,6 +205,13 @@ def _registered_models(namespace: str) -> dict[str, object]:
|
||||
for api in service_api_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "openapi":
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
models = dict(openapi_ns.models)
|
||||
for api in openapi_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
|
||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
|
||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
|
||||
OPENAPI_MAX_AGE_SECONDS: int = 600
|
||||
|
||||
|
||||
def _apply_cors_once(bp, /, **cors_kwargs):
|
||||
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.mcp import bp as mcp_bp
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.trigger import bp as trigger_bp
|
||||
from controllers.web import bp as web_bp
|
||||
@ -41,6 +44,23 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
if dify_config.OPENAPI_ENABLED:
|
||||
# User-scoped programmatic API. Default empty allowlist = same-origin
|
||||
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
|
||||
# integrations. supports_credentials so cookie-authed approve/deny
|
||||
# work; cross-origin OPTIONS without an allowed origin will fail
|
||||
# the same as on the console blueprint.
|
||||
_apply_cors_once(
|
||||
openapi_bp,
|
||||
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
_apply_cors_once(
|
||||
web_bp,
|
||||
resources={
|
||||
|
||||
@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token, extract_webapp_passport
|
||||
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
# approval-context) sit under @login_required and authenticate via
|
||||
# the console session cookie. Cookie-only on purpose — bearer
|
||||
# tokens (dfoa_/dfoe_) live on the Authorization header and are
|
||||
# validated by AppPipeline, not flask-login.
|
||||
cookie_token = extract_console_cookie_token(request)
|
||||
if not cookie_token:
|
||||
return None
|
||||
try:
|
||||
decoded = PassportService().verify(cookie_token)
|
||||
except Exception:
|
||||
return None
|
||||
user_id = decoded.get("user_id")
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
23
api/extensions/ext_oauth_bearer.py
Normal file
23
api/extensions/ext_oauth_bearer.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import build_and_bind
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return dify_config.ENABLE_OAUTH_BEARER
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
# scoped_session isn't a context manager; request teardown closes it.
|
||||
def session_factory():
|
||||
return db.session
|
||||
|
||||
build_and_bind(session_factory=session_factory, redis_client=redis_client)
|
||||
196
api/libs/device_flow_security.py
Normal file
196
api/libs/device_flow_security.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
|
||||
from flask import Blueprint
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
from libs.token import is_secure
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# enterprise_only decorator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
|
||||
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
|
||||
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
|
||||
|
||||
|
||||
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||
responses don't consume the bucket.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status not in _EE_ENABLED_STATUSES:
|
||||
raise NotFound()
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# approval_grant cookie
|
||||
# ============================================================================
|
||||
|
||||
|
||||
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
|
||||
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
|
||||
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
|
||||
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
|
||||
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
|
||||
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ApprovalGrantClaims:
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
user_code: str
|
||||
nonce: str
|
||||
csrf_token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
def mint_approval_grant(
|
||||
*,
|
||||
keyset: jws.KeySet,
|
||||
iss: str,
|
||||
subject_email: str,
|
||||
subject_issuer: str,
|
||||
user_code: str,
|
||||
) -> tuple[str, ApprovalGrantClaims]:
|
||||
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
|
||||
source of truth for Path/HttpOnly/Secure/SameSite.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
nonce = _random_opaque()
|
||||
csrf_token = _random_opaque()
|
||||
|
||||
payload = {
|
||||
"iss": iss,
|
||||
"subject_email": subject_email,
|
||||
"subject_issuer": subject_issuer,
|
||||
"user_code": user_code,
|
||||
"nonce": nonce,
|
||||
"csrf_token": csrf_token,
|
||||
}
|
||||
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
|
||||
return token, ApprovalGrantClaims(
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
user_code=user_code,
|
||||
nonce=nonce,
|
||||
csrf_token=csrf_token,
|
||||
expires_at=exp,
|
||||
)
|
||||
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
|
||||
"""Sig + aud + exp only — nonce consumption is the caller's job."""
|
||||
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||
return ApprovalGrantClaims(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
user_code=data["user_code"],
|
||||
nonce=data["nonce"],
|
||||
csrf_token=data["csrf_token"],
|
||||
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
|
||||
)
|
||||
|
||||
|
||||
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def approval_grant_cookie_kwargs(value: str) -> dict:
|
||||
"""``secure`` follows is_secure() so HTTP-only deployments don't
|
||||
silently drop the cookie.
|
||||
"""
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": value,
|
||||
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def approval_grant_cleared_cookie_kwargs() -> dict:
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": "",
|
||||
"max_age": 0,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def _random_opaque() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Anti-framing headers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_ANTI_FRAMING_HEADERS = {
|
||||
"X-Frame-Options": "DENY",
|
||||
"Content-Security-Policy": "frame-ancestors 'none'",
|
||||
}
|
||||
|
||||
|
||||
def attach_anti_framing(bp: Blueprint) -> None:
|
||||
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
||||
|
||||
@bp.after_request
|
||||
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
||||
for name, value in _ANTI_FRAMING_HEADERS.items():
|
||||
response.headers.setdefault(name, value)
|
||||
return response
|
||||
@ -76,6 +76,7 @@ def register_external_error_handlers(api: Api):
|
||||
|
||||
def handle_value_error(e: ValueError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
current_app.logger.exception("value_error in request handler")
|
||||
status_code = 400
|
||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
@ -577,3 +577,18 @@ class RateLimiter:
|
||||
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
def seconds_until_available(self, email: str) -> int:
|
||||
"""Seconds until the oldest in-window entry expires, freeing a slot.
|
||||
|
||||
Defensive floor of 1 second. Caller should only invoke this after
|
||||
is_rate_limited() returned True.
|
||||
"""
|
||||
key = self._get_key(email)
|
||||
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
|
||||
if not oldest:
|
||||
return 1
|
||||
_member, score = oldest[0]
|
||||
free_at = int(score) + self.time_window
|
||||
remaining = free_at - int(time.time())
|
||||
return max(remaining, 1)
|
||||
|
||||
108
api/libs/jws.py
Normal file
108
api/libs/jws.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
||||
state envelope, external subject assertion, and approval-grant cookie —
|
||||
all three share one key-set so api ↔ enterprise can verify each other.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
|
||||
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
|
||||
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
|
||||
|
||||
ACTIVE_KID_V1 = "dify-shared-v1"
|
||||
|
||||
|
||||
class KeySetError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KeySet:
|
||||
"""``from_entries`` reserves multi-kid construction for rotation slots."""
|
||||
|
||||
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
|
||||
if active_kid not in entries:
|
||||
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
|
||||
if not entries[active_kid]:
|
||||
raise KeySetError(f"active kid {active_kid!r} has empty secret")
|
||||
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
|
||||
self._active_kid = active_kid
|
||||
|
||||
@classmethod
|
||||
def from_shared_secret(cls) -> KeySet:
|
||||
secret = dify_config.SECRET_KEY
|
||||
if not secret:
|
||||
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
|
||||
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
|
||||
|
||||
@classmethod
|
||||
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
|
||||
return cls(entries, active_kid)
|
||||
|
||||
@property
|
||||
def active_kid(self) -> str:
|
||||
return self._active_kid
|
||||
|
||||
def lookup(self, kid: str) -> bytes | None:
|
||||
return self._entries.get(kid)
|
||||
|
||||
|
||||
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
|
||||
"""``iat`` + ``exp`` are injected here; callers must not set them."""
|
||||
if "aud" in payload or "iat" in payload or "exp" in payload:
|
||||
raise ValueError("reserved claim present in payload (aud/iat/exp)")
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds must be positive")
|
||||
|
||||
kid = keyset.active_kid
|
||||
secret = keyset.lookup(kid)
|
||||
if secret is None:
|
||||
raise KeySetError(f"active kid {kid!r} lookup miss")
|
||||
|
||||
iat = datetime.now(UTC)
|
||||
exp = iat + timedelta(seconds=ttl_seconds)
|
||||
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
|
||||
return jwt.encode(
|
||||
claims,
|
||||
secret,
|
||||
algorithm="HS256",
|
||||
headers={"kid": kid, "typ": "JWT"},
|
||||
)
|
||||
|
||||
|
||||
class VerifyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
|
||||
"""Unknown kid is rejected — never fall back to the active kid, since
|
||||
a past kid value would otherwise be forgeable by anyone who saw it.
|
||||
"""
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except jwt.PyJWTError as e:
|
||||
raise VerifyError(f"decode header: {e}") from e
|
||||
kid = header.get("kid")
|
||||
if not kid:
|
||||
raise VerifyError("no kid in header")
|
||||
secret = keyset.lookup(kid)
|
||||
if secret is None:
|
||||
raise VerifyError(f"unknown kid {kid!r}")
|
||||
try:
|
||||
return jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
audience=expected_aud,
|
||||
)
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise VerifyError("token expired") from e
|
||||
except jwt.InvalidAudienceError as e:
|
||||
raise VerifyError("aud mismatch") from e
|
||||
except jwt.PyJWTError as e:
|
||||
raise VerifyError(f"decode: {e}") from e
|
||||
650
api/libs/oauth_bearer.py
Normal file
650
api/libs/oauth_bearer.py
Normal file
@ -0,0 +1,650 @@
|
||||
"""OAuth bearer primitives.
|
||||
|
||||
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||
Authenticator + validate_bearer stay untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Literal, ParamSpec, Protocol, TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.rate_limit import enforce_bearer_rate_limit
|
||||
from models import Account, OAuthAccessToken, TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Contract — types, enums, protocols
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SubjectType(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
class Scope(StrEnum):
|
||||
"""Catalog of bearer scopes recognised by the openapi surface.
|
||||
|
||||
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
|
||||
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
|
||||
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
|
||||
"""
|
||||
|
||||
FULL = "full"
|
||||
APPS_READ = "apps:read"
|
||||
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
||||
APPS_RUN = "apps:run"
|
||||
|
||||
|
||||
class Accepts(StrEnum):
|
||||
"""Subject types a route is willing to accept as caller."""
|
||||
|
||||
USER_ACCOUNT = "user_account"
|
||||
USER_EXT_SSO = "user_ext_sso"
|
||||
|
||||
|
||||
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
|
||||
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
|
||||
|
||||
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
|
||||
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
|
||||
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AuthContext:
|
||||
"""Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
|
||||
come from the TokenKind, not the DB — corrupt rows can't elevate scope.
|
||||
|
||||
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
||||
authenticate time. Per-request mutations write through to Redis via
|
||||
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
|
||||
"""
|
||||
|
||||
subject_type: SubjectType
|
||||
subject_email: str | None
|
||||
subject_issuer: str | None
|
||||
account_id: uuid.UUID | None
|
||||
client_id: str | None
|
||||
scopes: frozenset[Scope]
|
||||
token_id: uuid.UUID
|
||||
source: str
|
||||
expires_at: datetime | None
|
||||
token_hash: str
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedRow:
|
||||
subject_email: str | None
|
||||
subject_issuer: str | None
|
||||
account_id: uuid.UUID | None
|
||||
client_id: str | None
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime | None
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def to_cache(self) -> dict:
|
||||
return {
|
||||
"subject_email": self.subject_email,
|
||||
"subject_issuer": self.subject_issuer,
|
||||
"account_id": str(self.account_id) if self.account_id else None,
|
||||
"client_id": self.client_id,
|
||||
"token_id": str(self.token_id),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"verified_tenants": dict(self.verified_tenants),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_cache(cls, data: dict) -> ResolvedRow:
|
||||
return cls(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
|
||||
client_id=data.get("client_id"),
|
||||
token_id=uuid.UUID(data["token_id"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
|
||||
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
|
||||
)
|
||||
|
||||
|
||||
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
|
||||
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
|
||||
|
||||
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
|
||||
on all live deployments (60s TTL → safe to drop one release after rollout).
|
||||
"""
|
||||
if not isinstance(raw, dict):
|
||||
return {}
|
||||
out: dict[str, bool] = {}
|
||||
for k, v in raw.items():
|
||||
if isinstance(v, bool):
|
||||
out[k] = v
|
||||
elif v == "ok":
|
||||
out[k] = True
|
||||
elif v == "denied":
|
||||
out[k] = False
|
||||
return out
|
||||
|
||||
|
||||
class Resolver(Protocol):
|
||||
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TokenKind:
|
||||
prefix: str
|
||||
subject_type: SubjectType
|
||||
scopes: frozenset[Scope]
|
||||
source: str
|
||||
resolver: Resolver
|
||||
|
||||
def matches(self, token: str) -> bool:
|
||||
return token.startswith(self.prefix)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MintProfile:
|
||||
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
|
||||
|
||||
Consumers:
|
||||
- ``build_registry`` reads scopes here so the resolve-time TokenKind
|
||||
cannot drift from the mint-time intent.
|
||||
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
|
||||
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
|
||||
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
|
||||
the (subject_type, prefix, scopes) triple a caller intends to mint
|
||||
against this table — a caller that assembles its own scope set
|
||||
from a non-canonical source will fail closed at approve time.
|
||||
"""
|
||||
|
||||
subject_type: SubjectType
|
||||
prefix: str
|
||||
scopes: frozenset[Scope]
|
||||
|
||||
|
||||
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
|
||||
SubjectType.ACCOUNT: MintProfile(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
prefix="dfoa_",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: MintProfile(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
prefix="dfoe_",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class InvalidBearerError(Exception):
|
||||
"""Token missing, unknown prefix, or no live row."""
|
||||
|
||||
|
||||
class TokenExpiredError(Exception):
|
||||
"""Hard-expire bookkeeping is the resolver's job before raising."""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenKindRegistry:
|
||||
def __init__(self, kinds: Iterable[TokenKind]) -> None:
|
||||
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
|
||||
prefixes = [k.prefix for k in self._kinds]
|
||||
if len(set(prefixes)) != len(prefixes):
|
||||
raise ValueError(f"duplicate prefix in registry: {prefixes}")
|
||||
|
||||
def find(self, token: str) -> TokenKind | None:
|
||||
for k in self._kinds:
|
||||
if k.matches(token):
|
||||
return k
|
||||
return None
|
||||
|
||||
def kinds(self) -> tuple[TokenKind, ...]:
|
||||
return self._kinds
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authenticator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class BearerAuthenticator:
|
||||
def __init__(self, registry: TokenKindRegistry) -> None:
|
||||
self._registry = registry
|
||||
|
||||
@property
|
||||
def registry(self) -> TokenKindRegistry:
|
||||
return self._registry
|
||||
|
||||
def authenticate(self, token: str) -> AuthContext:
|
||||
"""Identity + per-token rate limit (single source).
|
||||
|
||||
Both the openapi pipeline (`BearerCheck`) and the decorator
|
||||
(`validate_bearer`) call this — rate-limit fires exactly once per
|
||||
request regardless of which path hosts the route.
|
||||
"""
|
||||
kind = self._registry.find(token)
|
||||
if kind is None:
|
||||
raise InvalidBearerError("unknown token prefix")
|
||||
token_hash = sha256_hex(token)
|
||||
row = kind.resolver.resolve(token_hash)
|
||||
if row is None:
|
||||
raise InvalidBearerError("token unknown or revoked")
|
||||
enforce_bearer_rate_limit(token_hash)
|
||||
return AuthContext(
|
||||
subject_type=kind.subject_type,
|
||||
subject_email=row.subject_email,
|
||||
subject_issuer=row.subject_issuer,
|
||||
account_id=row.account_id,
|
||||
client_id=row.client_id,
|
||||
scopes=kind.scopes,
|
||||
token_id=row.token_id,
|
||||
source=kind.source,
|
||||
expires_at=row.expires_at,
|
||||
token_hash=token_hash,
|
||||
verified_tenants=dict(row.verified_tenants),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth access token resolver (PAT resolver would be a sibling class)
|
||||
# ============================================================================
|
||||
|
||||
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
|
||||
POSITIVE_TTL_SECONDS = 60
|
||||
NEGATIVE_TTL_SECONDS = 10
|
||||
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
|
||||
|
||||
ScopeVariant = Literal["account", "external_sso"]
|
||||
|
||||
|
||||
class OAuthAccessTokenResolver:
|
||||
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
|
||||
sharing DB + cache plumbing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory,
|
||||
redis_client,
|
||||
positive_ttl: int = POSITIVE_TTL_SECONDS,
|
||||
negative_ttl: int = NEGATIVE_TTL_SECONDS,
|
||||
) -> None:
|
||||
self.session_factory = session_factory
|
||||
self._redis = redis_client
|
||||
self._positive_ttl = positive_ttl
|
||||
self._negative_ttl = negative_ttl
|
||||
|
||||
def for_account(self) -> Resolver:
|
||||
return _VariantResolver(self, variant="account")
|
||||
|
||||
def for_external_sso(self) -> Resolver:
|
||||
return _VariantResolver(self, variant="external_sso")
|
||||
|
||||
def _cache_key(self, token_hash: str) -> str:
|
||||
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||
|
||||
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
|
||||
raw = self._redis.get(self._cache_key(token_hash))
|
||||
if raw is None:
|
||||
return None
|
||||
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
if text == "invalid":
|
||||
return "invalid"
|
||||
try:
|
||||
return ResolvedRow.from_cache(json.loads(text))
|
||||
except (ValueError, KeyError):
|
||||
logger.warning("auth:token cache entry malformed; treating as miss")
|
||||
return None
|
||||
|
||||
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
|
||||
self._redis.setex(
|
||||
self._cache_key(token_hash),
|
||||
self._positive_ttl,
|
||||
json.dumps(row.to_cache()),
|
||||
)
|
||||
|
||||
def cache_set_negative(self, token_hash: str) -> None:
|
||||
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
|
||||
|
||||
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
|
||||
"""Atomic CAS — only the worker that flips revoked_at emits audit;
|
||||
replays are idempotent.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAccessToken)
|
||||
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
|
||||
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
if result.rowcount == 1:
|
||||
logger.warning(
|
||||
"audit: %s token_id=%s",
|
||||
AUDIT_OAUTH_EXPIRED,
|
||||
row_id,
|
||||
extra={"audit": True, "token_id": str(row_id)},
|
||||
)
|
||||
self._redis.delete(self._cache_key(token_hash))
|
||||
self.cache_set_negative(token_hash)
|
||||
|
||||
|
||||
class _VariantResolver:
|
||||
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
|
||||
self._parent = parent
|
||||
self._variant = variant
|
||||
|
||||
def resolve(self, token_hash: str) -> ResolvedRow | None:
|
||||
cached = self._parent.cache_get(token_hash)
|
||||
if cached == "invalid":
|
||||
return None
|
||||
if cached is not None and not isinstance(cached, str):
|
||||
if not self._matches_variant(cached):
|
||||
return None
|
||||
return cached
|
||||
|
||||
# Flask-SQLAlchemy's scoped_session is request-bound and not a
|
||||
# context manager; use it directly.
|
||||
session = self._parent.session_factory()
|
||||
row = self._load_from_db(session, token_hash)
|
||||
if row is None:
|
||||
self._parent.cache_set_negative(token_hash)
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
if row.expires_at is not None and row.expires_at <= now:
|
||||
self._parent.hard_expire(session, row.id, token_hash)
|
||||
return None
|
||||
|
||||
if not self._matches_variant_model(row):
|
||||
logger.error(
|
||||
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
|
||||
row.id,
|
||||
row.prefix,
|
||||
)
|
||||
return None
|
||||
|
||||
resolved = ResolvedRow(
|
||||
subject_email=row.subject_email,
|
||||
subject_issuer=row.subject_issuer,
|
||||
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
|
||||
client_id=row.client_id,
|
||||
token_id=uuid.UUID(str(row.id)),
|
||||
expires_at=row.expires_at,
|
||||
)
|
||||
self._parent.cache_set_positive(token_hash, resolved)
|
||||
return resolved
|
||||
|
||||
def _matches_variant(self, row: ResolvedRow) -> bool:
|
||||
has_account = row.account_id is not None
|
||||
if self._variant == "account":
|
||||
return has_account
|
||||
return not has_account
|
||||
|
||||
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
|
||||
has_account = row.account_id is not None
|
||||
if self._variant == "account":
|
||||
return has_account and row.prefix == "dfoa_"
|
||||
return (not has_account) and row.prefix == "dfoe_"
|
||||
|
||||
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
|
||||
return (
|
||||
session.query(OAuthAccessToken)
|
||||
.filter(
|
||||
OAuthAccessToken.token_hash == token_hash,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Layer 0 — workspace membership cache + helper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
|
||||
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
|
||||
`auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
|
||||
rebuilds via authenticate() and re-runs Layer 0.
|
||||
"""
|
||||
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||
raw = redis_client.get(cache_key)
|
||||
if raw is None:
|
||||
return
|
||||
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
if text == "invalid":
|
||||
return
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
ttl = redis_client.ttl(cache_key)
|
||||
if ttl <= 0:
|
||||
return
|
||||
data.setdefault("verified_tenants", {})[tenant_id] = verdict
|
||||
redis_client.setex(cache_key, ttl, json.dumps(data))
|
||||
|
||||
|
||||
def check_workspace_membership(
|
||||
*,
|
||||
account_id: uuid.UUID | str,
|
||||
tenant_id: str,
|
||||
token_hash: str,
|
||||
cached_verdicts: dict[str, bool],
|
||||
) -> None:
|
||||
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
|
||||
|
||||
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
|
||||
inline helper (`require_workspace_member`). Caller is responsible for
|
||||
short-circuiting on EE / SSO subjects before invoking — this function
|
||||
runs the membership + active-status checks unconditionally.
|
||||
"""
|
||||
cached = cached_verdicts.get(tenant_id)
|
||||
if cached is True:
|
||||
return
|
||||
if cached is False:
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
join = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if join is None:
|
||||
record_layer0_verdict(token_hash, tenant_id, False)
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
|
||||
if status != "active":
|
||||
record_layer0_verdict(token_hash, tenant_id, False)
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
record_layer0_verdict(token_hash, tenant_id, True)
|
||||
|
||||
|
||||
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
|
||||
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
|
||||
|
||||
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
|
||||
(no `tenant_account_joins` row by definition).
|
||||
"""
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||
return
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=tenant_id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.verified_tenants,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decorator — route-level bearer gate
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_authenticator: BearerAuthenticator | None = None
|
||||
|
||||
|
||||
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
|
||||
global _authenticator
|
||||
_authenticator = authenticator
|
||||
|
||||
|
||||
def get_authenticator() -> BearerAuthenticator:
|
||||
if _authenticator is None:
|
||||
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
|
||||
return _authenticator
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
header = req.headers.get("Authorization", "")
|
||||
scheme, _, value = header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not value:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
_DP = ParamSpec("_DP")
|
||||
_DR = TypeVar("_DR")
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
|
||||
"""Opt-in: omitting it leaves the route unauthenticated.
|
||||
|
||||
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
|
||||
``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
|
||||
and are rejected here as the wrong auth scheme for this surface.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
|
||||
@wraps(fn)
|
||||
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
|
||||
token = _extract_bearer(request)
|
||||
if token is None:
|
||||
raise Unauthorized("missing bearer token")
|
||||
|
||||
if _authenticator is None:
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
|
||||
try:
|
||||
ctx = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
|
||||
raise Forbidden("token subject type not accepted here")
|
||||
|
||||
g.auth_ctx = ctx
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
||||
without the authenticator, so fail fast instead of approving silently.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ENABLE_OAUTH_BEARER:
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def require_scope(scope: Scope) -> Callable:
|
||||
"""Route-level scope gate — must run AFTER validate_bearer so that
|
||||
g.auth_ctx is set. Raises Forbidden('insufficient_scope: <scope>')
|
||||
when the bearer lacks both the requested scope and `Scope.FULL`.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable) -> Callable:
|
||||
@wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
|
||||
)
|
||||
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
|
||||
raise Forbidden(f"insufficient_scope: {scope}")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Wiring — called once from the app factory
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
|
||||
account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
return TokenKindRegistry(
|
||||
[
|
||||
TokenKind(
|
||||
prefix=account.prefix,
|
||||
subject_type=account.subject_type,
|
||||
scopes=account.scopes,
|
||||
source="oauth_account",
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix=external.prefix,
|
||||
subject_type=external.subject_type,
|
||||
scopes=external.scopes,
|
||||
source="oauth_external_sso",
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
|
||||
registry = build_registry(session_factory, redis_client)
|
||||
auth = BearerAuthenticator(registry)
|
||||
bind_authenticator(auth)
|
||||
return auth
|
||||
140
api/libs/rate_limit.py
Normal file
140
api/libs/rate_limit.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
||||
window Redis ZSET). Apply after auth decorators so scopes can read
|
||||
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
|
||||
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
||||
generic 429.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import g, jsonify, make_response, request, session
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
|
||||
|
||||
class RateLimitScope(StrEnum):
|
||||
IP = "ip"
|
||||
SESSION = "session"
|
||||
ACCOUNT = "account"
|
||||
SUBJECT_EMAIL = "subject_email"
|
||||
TOKEN_ID = "token_id"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimit:
|
||||
limit: int
|
||||
window: timedelta
|
||||
scopes: tuple[RateLimitScope, ...]
|
||||
|
||||
|
||||
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
|
||||
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
|
||||
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
|
||||
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||
LIMIT_BEARER_PER_TOKEN = RateLimit(
|
||||
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
|
||||
window=timedelta(minutes=1),
|
||||
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
|
||||
)
|
||||
|
||||
|
||||
def _one_key(scope: RateLimitScope) -> str:
|
||||
match scope:
|
||||
case RateLimitScope.IP:
|
||||
return f"ip:{extract_remote_ip(request) or 'unknown'}"
|
||||
case RateLimitScope.SESSION:
|
||||
return f"session:{session.get('_id', 'anon')}"
|
||||
case RateLimitScope.ACCOUNT:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.account_id:
|
||||
return f"account:{ctx.account_id}"
|
||||
return "account:anon"
|
||||
case RateLimitScope.SUBJECT_EMAIL:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.subject_email:
|
||||
return f"subject:{ctx.subject_email}"
|
||||
return "subject:anon"
|
||||
case RateLimitScope.TOKEN_ID:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.token_id:
|
||||
return f"token:{ctx.token_id}"
|
||||
return "token:anon"
|
||||
|
||||
|
||||
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "|".join(_one_key(s) for s in scopes)
|
||||
|
||||
|
||||
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "rl:" + "+".join(s.value for s in scopes)
|
||||
|
||||
|
||||
def _build_limiter(spec: RateLimit) -> RateLimiter:
|
||||
return RateLimiter(
|
||||
prefix=_limiter_prefix(spec.scopes),
|
||||
max_attempts=spec.limit,
|
||||
time_window=int(spec.window.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||
"""Apply after auth decorators that the scopes read from."""
|
||||
limiter = _build_limiter(spec)
|
||||
|
||||
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
key = _composite_key(spec.scopes)
|
||||
if limiter.is_rate_limited(key):
|
||||
raise TooManyRequests("rate_limited")
|
||||
limiter.increment_rate_limit(key)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def enforce(spec: RateLimit, *, key: str) -> None:
|
||||
"""Imperative form — caller composes the bucket key to match scope
|
||||
semantics (the key is opaque here).
|
||||
"""
|
||||
limiter = _build_limiter(spec)
|
||||
if limiter.is_rate_limited(key):
|
||||
raise TooManyRequests("rate_limited")
|
||||
limiter.increment_rate_limit(key)
|
||||
|
||||
|
||||
def enforce_bearer_rate_limit(token_hash: str) -> None:
|
||||
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
|
||||
|
||||
Bucket key = ``token:<sha256_hex>`` so the same token shares one
|
||||
bucket across api replicas (Redis-backed sliding window).
|
||||
"""
|
||||
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
|
||||
key = f"token:{token_hash}"
|
||||
if limiter.is_rate_limited(key):
|
||||
retry_after = limiter.seconds_until_available(key)
|
||||
response = make_response(
|
||||
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
|
||||
429,
|
||||
)
|
||||
response.headers["Retry-After"] = str(retry_after)
|
||||
raise TooManyRequests(response=response)
|
||||
limiter.increment_rate_limit(key)
|
||||
@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
|
||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
def extract_console_cookie_token(request: Request) -> str | None:
|
||||
"""Cookie-only console session token. Used by /openapi/v1/oauth/device/*
|
||||
approval routes, which must not fall through to the Authorization header
|
||||
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
|
||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
|
||||
return extract_console_cookie_token(request) or _try_extract_from_header(request)
|
||||
|
||||
|
||||
def extract_webapp_access_token(request: Request) -> str | None:
|
||||
|
||||
@ -0,0 +1,100 @@
|
||||
"""add oauth_access_tokens table
|
||||
|
||||
Revision ID: d4a5e1f3c9b7
|
||||
Revises: a4f2d8c9b731
|
||||
Create Date: 2026-04-23 22:00:00.000000
|
||||
|
||||
Table stores user-level OAuth bearer tokens minted via the device-flow grant
|
||||
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
|
||||
table not added in this migration.
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d4a5e1f3c9b7"
|
||||
down_revision = "a4f2d8c9b731"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"oauth_access_tokens",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column("subject_email", sa.Text(), nullable=False),
|
||||
sa.Column("subject_issuer", sa.Text(), nullable=True),
|
||||
sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("device_label", sa.Text(), nullable=False),
|
||||
sa.Column("prefix", sa.String(length=8), nullable=False),
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
|
||||
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["account_id"],
|
||||
["accounts.id"],
|
||||
name="fk_oauth_access_tokens_account_id",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"idx_oauth_subject_email",
|
||||
"oauth_access_tokens",
|
||||
["subject_email"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_account",
|
||||
"oauth_access_tokens",
|
||||
["account_id"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_client",
|
||||
"oauth_access_tokens",
|
||||
["subject_email", "client_id"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_token_hash",
|
||||
"oauth_access_tokens",
|
||||
["token_hash"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
# Partial unique index — rotate-in-place keyed on (subject, client, device).
|
||||
# The app always writes a non-NULL subject_issuer (account flow uses a
|
||||
# sentinel, external-SSO uses the verified IdP issuer); without that the
|
||||
# composite key would never collide because Postgres treats NULLs as
|
||||
# distinct in unique indices.
|
||||
op.create_index(
|
||||
"uq_oauth_active_per_device",
|
||||
"oauth_access_tokens",
|
||||
["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_client", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_account", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens")
|
||||
op.drop_table("oauth_access_tokens")
|
||||
@ -73,7 +73,7 @@ from .model import (
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
|
||||
from .provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -177,6 +177,7 @@ __all__ = [
|
||||
"MessageChain",
|
||||
"MessageFeedback",
|
||||
"MessageFile",
|
||||
"OAuthAccessToken",
|
||||
"OperationLog",
|
||||
"PinnedConversation",
|
||||
"Provider",
|
||||
|
||||
@ -84,3 +84,35 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(TypeBase):
|
||||
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
|
||||
subject_issuer = "dify:account" sentinel); account_id NULL +
|
||||
subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
|
||||
subject_issuer is non-NULL for all rows the app writes — Postgres
|
||||
treats NULLs as distinct in unique indices, so the partial unique
|
||||
index on (subject_email, subject_issuer, client_id, device_label)
|
||||
WHERE revoked_at IS NULL would otherwise fail to rotate in place.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_access_tokens"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
|
||||
)
|
||||
subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
|
||||
device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
|
||||
subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
|
||||
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
|
||||
)
|
||||
|
||||
@ -1209,6 +1209,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
INSTALLED_APP = "installed-app"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
|
||||
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
|
||||
(token_id stays for audits) so rows accumulate across re-logins, and
|
||||
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import click
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 500
|
||||
|
||||
|
||||
@app.celery.task(queue="retention")
|
||||
def clean_oauth_access_tokens_task():
|
||||
click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
|
||||
retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=retention_days)
|
||||
start_at = time.perf_counter()
|
||||
|
||||
candidates = or_(
|
||||
OAuthAccessToken.revoked_at < cutoff,
|
||||
# Zombies: expired but never re-presented, so middleware never flipped them.
|
||||
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
|
||||
)
|
||||
|
||||
total = 0
|
||||
while True:
|
||||
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
|
||||
if not ids:
|
||||
break
|
||||
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
|
||||
db.session.commit()
|
||||
total += len(ids)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -39,6 +39,8 @@ class AppListParams(BaseModel):
|
||||
name: str | None = None
|
||||
tag_ids: list[str] | None = None
|
||||
is_created_by_me: bool | None = None
|
||||
status: str | None = None
|
||||
openapi_visible: bool = False
|
||||
|
||||
|
||||
class CreateAppParams(BaseModel):
|
||||
@ -75,6 +77,14 @@ class AppService:
|
||||
elif params.mode == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if params.status:
|
||||
filters.append(App.status == params.status)
|
||||
# OpenAPI surface visibility gate. Pushed into the query so
|
||||
# `pagination.total` reflects only apps the openapi caller can
|
||||
# actually reach — post-filtering by enable_api after the page
|
||||
# arrives would make `total` page-dependent.
|
||||
if params.openapi_visible:
|
||||
filters.append(App.enable_api.is_(True))
|
||||
if params.is_created_by_me:
|
||||
filters.append(App.created_by == user_id)
|
||||
if params.name:
|
||||
|
||||
44
api/services/enterprise/app_permitted_service.py
Normal file
44
api/services/enterprise/app_permitted_service.py
Normal file
@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PermittedAppsPage:
|
||||
app_ids: list[str]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
def list_permitted_apps(
|
||||
*,
|
||||
page: int,
|
||||
limit: int,
|
||||
mode: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> PermittedAppsPage:
|
||||
try:
|
||||
body = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
|
||||
page=page, limit=limit, mode=mode, name=name
|
||||
)
|
||||
except EnterpriseAPIError as exc:
|
||||
logger.warning(
|
||||
"permitted_apps EE call failed: status=%s message=%s",
|
||||
getattr(exc, "status_code", None),
|
||||
str(exc),
|
||||
)
|
||||
raise ServiceUnavailable("permitted_apps_unavailable") from exc
|
||||
|
||||
return PermittedAppsPage(
|
||||
app_ids=[row["appId"] for row in body.get("data", [])],
|
||||
total=int(body.get("total", 0)),
|
||||
has_more=bool(body.get("hasMore", False)),
|
||||
)
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -24,10 +25,22 @@ VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||
|
||||
|
||||
class WebAppAccessMode(enum.StrEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
PRIVATE_ALL = "private_all"
|
||||
SSO_VERIFIED = "sso_verified"
|
||||
|
||||
|
||||
PERMISSION_CHECK_MODES: frozenset[WebAppAccessMode] = frozenset(
|
||||
{WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||
)
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
|
||||
default="private",
|
||||
description=f"Access mode for the web app. One of: {', '.join(m.value for m in WebAppAccessMode)}",
|
||||
default=WebAppAccessMode.PRIVATE.value,
|
||||
alias="accessMode",
|
||||
)
|
||||
|
||||
@ -108,6 +121,15 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/device-flow/sso-initiate",
|
||||
json={"signed_state": signed_state},
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
@ -219,8 +241,9 @@ class EnterpriseService:
|
||||
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
if access_mode not in ["public", "private", "private_all"]:
|
||||
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
||||
allowed = {WebAppAccessMode.PUBLIC, WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||
if access_mode not in allowed:
|
||||
raise ValueError(f"access_mode must be one of: {', '.join(m.value for m in allowed)}")
|
||||
|
||||
data = {"appId": app_id, "accessMode": access_mode}
|
||||
|
||||
@ -236,6 +259,32 @@ class EnterpriseService:
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def list_externally_accessible_apps(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
limit: int,
|
||||
mode: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> dict:
|
||||
"""Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.
|
||||
|
||||
Response shape: ``{"data": [{"appId", "tenantId", "mode", "name", "updatedAt"}],
|
||||
"total": int, "hasMore": bool}``.
|
||||
"""
|
||||
body: dict[str, str | int] = {"page": page, "limit": limit}
|
||||
if mode is not None:
|
||||
body["mode"] = mode
|
||||
if name is not None:
|
||||
body["name"] = name
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/webapp/externally-accessible-apps",
|
||||
json=body,
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
@ -159,7 +159,6 @@ class PluginManagerModel(BaseModel):
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
app_dsl_version: str = ""
|
||||
enable_app_deploy: bool = False
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
enable_marketplace: bool = False
|
||||
@ -234,7 +233,6 @@ class FeatureService:
|
||||
cls._fulfill_system_params_from_env(system_features)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
system_features.enable_app_deploy = True
|
||||
system_features.branding.enabled = True
|
||||
system_features.webapp_auth.enabled = True
|
||||
system_features.enable_change_email = False
|
||||
|
||||
467
api/services/oauth_device_flow.py
Normal file
467
api/services/oauth_device_flow.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Device-flow service layer: Redis state machine, OAuth token mint
|
||||
(DB upsert + plaintext generation), and TTL policy. Specs:
|
||||
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Redis state machine — device_code + user_code ephemeral state
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_DEVICE_CODE_KEY_PREFIX = "device_code:"
|
||||
_USER_CODE_KEY_PREFIX = "user_code:"
|
||||
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
|
||||
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
|
||||
|
||||
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
|
||||
# not both observe APPROVED — only the winner gets the plaintext token,
|
||||
# the loser sees nil and the caller maps that to expired_token.
|
||||
_CONSUME_ON_POLL_LUA = """
|
||||
local raw = redis.call('GET', KEYS[1])
|
||||
if not raw then return nil end
|
||||
local ok, decoded = pcall(cjson.decode, raw)
|
||||
if not ok then return nil end
|
||||
if decoded.status == 'pending' then return nil end
|
||||
if decoded.user_code then
|
||||
redis.call('DEL', ARGV[1] .. decoded.user_code)
|
||||
end
|
||||
redis.call('DEL', KEYS[1])
|
||||
return raw
|
||||
"""
|
||||
|
||||
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
|
||||
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
|
||||
|
||||
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
|
||||
USER_CODE_SEGMENT_LEN = 4
|
||||
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
|
||||
|
||||
DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum
|
||||
|
||||
|
||||
class DeviceFlowStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
|
||||
|
||||
class SlowDownDecision(StrEnum):
|
||||
OK = "ok"
|
||||
SLOW_DOWN = "slow_down"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceFlowState:
|
||||
"""``minted_token`` is plaintext between approve and the next poll;
|
||||
DEL'd after the poll reads it.
|
||||
"""
|
||||
|
||||
user_code: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
status: DeviceFlowStatus
|
||||
subject_email: str | None = None
|
||||
account_id: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
minted_token: str | None = None
|
||||
token_id: str | None = None
|
||||
created_at: str = ""
|
||||
created_ip: str = ""
|
||||
last_poll_at: str = ""
|
||||
poll_payload: dict | None = field(default=None)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> DeviceFlowState:
|
||||
data = json.loads(raw)
|
||||
if "status" in data:
|
||||
data["status"] = DeviceFlowStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
def _random_device_code() -> str:
|
||||
return "dc_" + secrets.token_urlsafe(24)
|
||||
|
||||
|
||||
def _random_user_code_segment() -> str:
|
||||
return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
|
||||
|
||||
|
||||
def _random_user_code() -> str:
|
||||
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
|
||||
|
||||
|
||||
class StateNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTransitionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserCodeExhaustedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceFlowRedis:
|
||||
def __init__(self, redis_client) -> None:
|
||||
self._redis = redis_client
|
||||
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
|
||||
|
||||
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
|
||||
device_code = _random_device_code()
|
||||
user_code = self._claim_user_code(device_code)
|
||||
state = DeviceFlowState(
|
||||
user_code=user_code,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
status=DeviceFlowStatus.PENDING,
|
||||
created_at=datetime.now(UTC).isoformat(),
|
||||
created_ip=created_ip,
|
||||
)
|
||||
self._redis.setex(
|
||||
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
state.to_json(),
|
||||
)
|
||||
return device_code, user_code, DEVICE_FLOW_TTL_SECONDS
|
||||
|
||||
def _claim_user_code(self, device_code: str) -> str:
|
||||
for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
|
||||
user_code = _random_user_code()
|
||||
key = USER_CODE_KEY_FMT.format(code=user_code)
|
||||
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
|
||||
if ok:
|
||||
return user_code
|
||||
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
|
||||
|
||||
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
|
||||
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
|
||||
if not raw_dc:
|
||||
return None
|
||||
device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
return None
|
||||
return device_code, state
|
||||
|
||||
def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
|
||||
return self._load_state(device_code)
|
||||
|
||||
def _load_state(self, device_code: str) -> DeviceFlowState | None:
|
||||
raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||
if not raw:
|
||||
return None
|
||||
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.exception("device_flow: corrupt state for %s", device_code)
|
||||
return None
|
||||
|
||||
def approve(
|
||||
self,
|
||||
device_code: str,
|
||||
subject_email: str,
|
||||
account_id: str | None,
|
||||
minted_token: str,
|
||||
token_id: str,
|
||||
subject_issuer: str | None = None,
|
||||
poll_payload: dict | None = None,
|
||||
) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransitionError(f"cannot approve {state.status}")
|
||||
|
||||
state.status = DeviceFlowStatus.APPROVED
|
||||
state.subject_email = subject_email
|
||||
state.account_id = account_id
|
||||
state.subject_issuer = subject_issuer
|
||||
state.minted_token = minted_token
|
||||
state.token_id = token_id
|
||||
state.poll_payload = poll_payload
|
||||
|
||||
new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
|
||||
self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())
|
||||
|
||||
def deny(self, device_code: str) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransitionError(f"cannot deny {state.status}")
|
||||
state.status = DeviceFlowStatus.DENIED
|
||||
self._redis.setex(
|
||||
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||
self._remaining_ttl(device_code, floor=1),
|
||||
state.to_json(),
|
||||
)
|
||||
|
||||
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
|
||||
"""Race-safe via Lua EVAL: GET + status-check + DEL execute in a
|
||||
single Redis transaction so only one of N concurrent pollers
|
||||
observes the APPROVED state. Losers get None, mapped to
|
||||
expired_token by the caller.
|
||||
"""
|
||||
raw = self._consume_on_poll_script(
|
||||
keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
|
||||
args=[_USER_CODE_KEY_PREFIX],
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.exception("device_flow: corrupt state on consume %s", device_code)
|
||||
return None
|
||||
|
||||
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
|
||||
now = time.time()
|
||||
key = f"device_code:{device_code}:last_poll"
|
||||
prev_raw = self._redis.get(key)
|
||||
self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
|
||||
if prev_raw is None:
|
||||
return SlowDownDecision.OK
|
||||
prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
|
||||
try:
|
||||
prev = float(prev_s)
|
||||
except ValueError:
|
||||
return SlowDownDecision.OK
|
||||
if now - prev < interval_seconds:
|
||||
return SlowDownDecision.SLOW_DOWN
|
||||
return SlowDownDecision.OK
|
||||
|
||||
def _remaining_ttl(self, device_code: str, floor: int) -> int:
|
||||
"""``max(remaining, floor)`` — guarantees the CLI has at least
|
||||
``floor`` seconds to poll after a near-expiry approve.
|
||||
"""
|
||||
ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||
if ttl is None or ttl < 0:
|
||||
return floor
|
||||
return max(int(ttl), floor)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token mint — generate + upsert
|
||||
# ============================================================================
|
||||
|
||||
|
||||
OAUTH_BODY_BYTES = 32 # ~256 bits entropy
|
||||
PREFIX_OAUTH_ACCOUNT = "dfoa_"
|
||||
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
|
||||
|
||||
# Sentinel issuer for account-flow rows. Postgres' default partial unique
|
||||
# index treats NULLs as distinct, which would let two live `dfoa_` rows
|
||||
# share (email, client, device) and break rotate-in-place. Storing a
|
||||
# non-empty literal makes the composite key collide as intended.
|
||||
ACCOUNT_ISSUER_SENTINEL = "dify:account"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MintResult:
|
||||
"""Plaintext token surfaces to the caller once."""
|
||||
|
||||
token: str
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class UpsertOutcome:
|
||||
token_id: uuid.UUID
|
||||
rotated: bool
|
||||
old_hash: str | None
|
||||
|
||||
|
||||
def generate_token(prefix: str) -> str:
|
||||
return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def mint_oauth_token(
|
||||
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
|
||||
# the wrapper proxies the same execute/commit surface.
|
||||
session: Session | scoped_session,
|
||||
redis_client,
|
||||
*,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
account_id: str | None,
|
||||
client_id: str,
|
||||
device_label: str,
|
||||
prefix: str,
|
||||
ttl_days: int,
|
||||
) -> MintResult:
|
||||
"""Live row rotates in place via partial unique index
|
||||
``uq_oauth_active_per_device``; hard-expired rows are excluded by the
|
||||
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
|
||||
deleted so stale AuthContext drops immediately.
|
||||
"""
|
||||
if prefix == PREFIX_OAUTH_ACCOUNT:
|
||||
# Account flow always writes the sentinel — caller may pass None
|
||||
# (for clarity) or the sentinel itself; nothing else is valid.
|
||||
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
|
||||
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
|
||||
subject_issuer = ACCOUNT_ISSUER_SENTINEL
|
||||
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
|
||||
# Defense in depth: enterprise canonicalises + rejects empty,
|
||||
# but a regression there must not yield a NULL composite key here.
|
||||
if not subject_issuer or not subject_issuer.strip():
|
||||
raise ValueError("external-SSO token requires non-empty subject_issuer")
|
||||
else:
|
||||
raise ValueError(f"unknown oauth prefix: {prefix!r}")
|
||||
|
||||
token = generate_token(prefix)
|
||||
new_hash = sha256_hex(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
|
||||
|
||||
outcome = _upsert(
|
||||
session,
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
account_id=account_id,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
prefix=prefix,
|
||||
new_hash=new_hash,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if outcome.rotated and outcome.old_hash:
|
||||
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
|
||||
|
||||
return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
|
||||
|
||||
|
||||
def _upsert(
|
||||
session: Session | scoped_session,
|
||||
*,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
account_id: str | None,
|
||||
client_id: str,
|
||||
device_label: str,
|
||||
prefix: str,
|
||||
new_hash: str,
|
||||
expires_at: datetime,
|
||||
) -> UpsertOutcome:
|
||||
# Snapshot prior live row's hash for Redis invalidation post-rotate.
|
||||
# subject_issuer is always non-null here (account flow uses sentinel,
|
||||
# external-SSO is validated upstream), so equality matches the index.
|
||||
prior = session.execute(
|
||||
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
|
||||
.where(
|
||||
OAuthAccessToken.subject_email == subject_email,
|
||||
OAuthAccessToken.subject_issuer == subject_issuer,
|
||||
OAuthAccessToken.client_id == client_id,
|
||||
OAuthAccessToken.device_label == device_label,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
old_hash = prior.token_hash if prior else None
|
||||
|
||||
insert_stmt = pg_insert(OAuthAccessToken).values(
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
account_id=account_id,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
prefix=prefix,
|
||||
token_hash=new_hash,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
upsert_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||
index_where=OAuthAccessToken.revoked_at.is_(None),
|
||||
set_={
|
||||
"token_hash": insert_stmt.excluded.token_hash,
|
||||
"prefix": insert_stmt.excluded.prefix,
|
||||
"account_id": insert_stmt.excluded.account_id,
|
||||
"expires_at": insert_stmt.excluded.expires_at,
|
||||
"created_at": func.now(),
|
||||
"last_used_at": None,
|
||||
},
|
||||
).returning(OAuthAccessToken.id)
|
||||
row = session.execute(upsert_stmt).first()
|
||||
session.commit()
|
||||
|
||||
if row is None:
|
||||
raise RuntimeError("oauth_token upsert returned no row")
|
||||
token_id = uuid.UUID(str(row.id))
|
||||
return UpsertOutcome(
|
||||
token_id=token_id,
|
||||
rotated=prior is not None,
|
||||
old_hash=old_hash,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TTL policy — days new OAuth tokens live
|
||||
# ============================================================================
|
||||
|
||||
|
||||
DEFAULT_OAUTH_TTL_DAYS = 14
|
||||
MIN_TTL_DAYS = 1
|
||||
MAX_TTL_DAYS = 365
|
||||
|
||||
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
|
||||
|
||||
|
||||
def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
|
||||
is deferred; when it lands it wins over the env (Redis-cached 60s).
|
||||
"""
|
||||
_ = tenant_id
|
||||
|
||||
raw = os.environ.get(_TTL_ENV_VAR)
|
||||
if raw is None:
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"%s=%r is not an int; falling back to %d",
|
||||
_TTL_ENV_VAR,
|
||||
raw,
|
||||
DEFAULT_OAUTH_TTL_DAYS,
|
||||
)
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
if value < MIN_TTL_DAYS:
|
||||
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
|
||||
return MIN_TTL_DAYS
|
||||
if value > MAX_TTL_DAYS:
|
||||
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
|
||||
return MAX_TTL_DAYS
|
||||
return value
|
||||
0
api/services/openapi/__init__.py
Normal file
0
api/services/openapi/__init__.py
Normal file
52
api/services/openapi/license_gate.py
Normal file
52
api/services/openapi/license_gate.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""License gate for the /openapi/v1/permitted-external-apps* surface.
|
||||
|
||||
EE-only. CE deploys (``ENTERPRISE_ENABLED=false``) skip the gate entirely —
|
||||
the EE blueprint chain is what gives CE deploys no callers on this surface
|
||||
in practice, but the explicit short-circuit avoids any test/fixture that
|
||||
flips the surface on without flipping the license.
|
||||
|
||||
Reuses ``FeatureService.get_system_features()`` so the license status
|
||||
travels the same path as the console reads.
|
||||
|
||||
Companion to ``controllers.console.wraps.enterprise_license_required`` —
|
||||
that one is for console (cookie-authed, force-logout 401). This one is
|
||||
for bearer surface (token-authed, 403 ``license_required``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VALID_LICENSE_STATUSES: frozenset[LicenseStatus] = frozenset({LicenseStatus.ACTIVE, LicenseStatus.EXPIRING})
|
||||
|
||||
|
||||
def license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Decorator form. Raises ``Forbidden('license_required')`` when the EE
|
||||
deployment has no valid license. No-op on CE (``ENTERPRISE_ENABLED=false``).
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if dify_config.ENTERPRISE_ENABLED and not _is_license_valid():
|
||||
raise Forbidden(description="license_required")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def _is_license_valid() -> bool:
|
||||
try:
|
||||
features = FeatureService.get_system_features()
|
||||
except Exception:
|
||||
logger.exception("license_gate: FeatureService.get_system_features failed")
|
||||
return False
|
||||
return features.license.status in _VALID_LICENSE_STATUSES
|
||||
47
api/services/openapi/mint_policy.py
Normal file
47
api/services/openapi/mint_policy.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Hard mint policy.
|
||||
|
||||
``validate_mint_policy`` cross-checks a (subject_type, prefix, scopes)
|
||||
triple a caller intends to mint against ``MINTABLE_PROFILES`` —
|
||||
the single source of truth in ``libs.oauth_bearer``.
|
||||
|
||||
The defense-in-depth value: if a future caller assembles ``prefix`` or
|
||||
``scopes`` from a non-canonical source (env, request body, plug-in
|
||||
contribution), the mismatch fails closed at approve time before any
|
||||
row hits the DB. When the caller reads straight from
|
||||
``MINTABLE_PROFILES``, the check is a structural pin — it confirms the
|
||||
table entry is well-formed and the caller picked the right key.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, Scope, SubjectType
|
||||
|
||||
|
||||
class MintPolicyViolation(Exception): # noqa: N818 — spec-defined name, used in BadRequest message
|
||||
"""Raised on a (subject_type, prefix, scopes) mismatch. Callers translate
|
||||
to 400 ``mint_policy_violation``."""
|
||||
|
||||
|
||||
def validate_mint_policy(
|
||||
*,
|
||||
subject_type: SubjectType,
|
||||
prefix: str,
|
||||
scopes: frozenset[Scope],
|
||||
) -> None:
|
||||
"""Raise ``MintPolicyViolation`` when the triple does not match the
|
||||
canonical ``MINTABLE_PROFILES`` entry for ``subject_type``.
|
||||
"""
|
||||
profile = MINTABLE_PROFILES.get(subject_type)
|
||||
if profile is None:
|
||||
raise MintPolicyViolation(f"mint_policy_violation: unknown subject_type={subject_type!r}")
|
||||
|
||||
drift = []
|
||||
if profile.prefix != prefix:
|
||||
drift.append(f"prefix got={prefix!r} expected={profile.prefix!r}")
|
||||
if frozenset(scopes) != profile.scopes:
|
||||
got = sorted(s.value for s in scopes)
|
||||
want = sorted(s.value for s in profile.scopes)
|
||||
drift.append(f"scopes got={got} expected={want}")
|
||||
|
||||
if drift:
|
||||
raise MintPolicyViolation(f"mint_policy_violation: subject_type={subject_type.value} — " + "; ".join(drift))
|
||||
32
api/services/openapi/visibility.py
Normal file
32
api/services/openapi/visibility.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Single-source visibility filter for the /openapi/v1/* surface.
|
||||
|
||||
Keep every openapi-surface app query routed through ``_apply_openapi_gate``;
|
||||
retiring or replacing the gate then becomes a one-line change here.
|
||||
|
||||
The Service API (/v1/* app-key surface) does NOT use this helper — that
|
||||
surface has its own per-request guard (``service_api_disabled``) wired
|
||||
into the legacy ``validate_app_token`` decorator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from models.model import App
|
||||
|
||||
|
||||
def apply_openapi_gate(query: Any) -> Any:
|
||||
"""Filter a SQLAlchemy Select/Query to apps visible on /openapi/v1/*.
|
||||
|
||||
Works with both legacy ``Query.filter`` and 2.0-style ``Select.filter``
|
||||
(alias of ``.where``).
|
||||
"""
|
||||
return query.filter(App.enable_api.is_(True))
|
||||
|
||||
|
||||
def is_openapi_visible(app: App) -> bool:
|
||||
"""Per-row counterpart for code paths that fetch an App by primary key
|
||||
(``session.get`` / ``session.scalar``) and need the same visibility check
|
||||
the query gate would have applied.
|
||||
"""
|
||||
return bool(app.enable_api)
|
||||
@ -15,7 +15,7 @@ from models import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.enterprise.enterprise_service import PERMISSION_CHECK_MODES, EnterpriseService, WebAppAccessMode
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||
|
||||
@ -137,12 +137,8 @@ class WebAppAuthService:
|
||||
"""
|
||||
Check if the app requires permission check based on its access mode.
|
||||
"""
|
||||
modes_requiring_permission_check = [
|
||||
"private",
|
||||
"private_all",
|
||||
]
|
||||
if access_mode:
|
||||
return access_mode in modes_requiring_permission_check
|
||||
return access_mode in PERMISSION_CHECK_MODES
|
||||
|
||||
if not app_code and not app_id:
|
||||
raise ValueError("Either app_code or app_id must be provided.")
|
||||
@ -153,7 +149,7 @@ class WebAppAuthService:
|
||||
raise ValueError("App ID could not be determined from the provided app_code.")
|
||||
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
|
||||
if webapp_settings and webapp_settings.access_mode in PERMISSION_CHECK_MODES:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -166,11 +162,11 @@ class WebAppAuthService:
|
||||
raise ValueError("Either app_code or access_mode must be provided.")
|
||||
|
||||
if access_mode:
|
||||
if access_mode == "public":
|
||||
if access_mode == WebAppAccessMode.PUBLIC:
|
||||
return WebAppAuthType.PUBLIC
|
||||
elif access_mode in ["private", "private_all"]:
|
||||
elif access_mode in PERMISSION_CHECK_MODES:
|
||||
return WebAppAuthType.INTERNAL
|
||||
elif access_mode == "sso_verified":
|
||||
elif access_mode == WebAppAccessMode.SSO_VERIFIED:
|
||||
return WebAppAuthType.EXTERNAL
|
||||
|
||||
if app_code:
|
||||
|
||||
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""Shared fixtures for /openapi/v1/* integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus
|
||||
|
||||
|
||||
def _sha256(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_enterprise(monkeypatch):
|
||||
"""Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
|
||||
EE branch override this with their own monkeypatch in-test."""
|
||||
from configs import dify_config
|
||||
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
|
||||
with flask_app.app_context():
|
||||
tenant = Tenant(name="t1", status="normal")
|
||||
account = Account(email="u@example.com", name="u")
|
||||
db.session.add_all([tenant, account])
|
||||
db.session.commit()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
yield account, tenant, join
|
||||
db.session.delete(join)
|
||||
db.session.delete(account)
|
||||
db.session.delete(tenant)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
|
||||
_, tenant, _ = workspace_account
|
||||
with flask_app.app_context():
|
||||
app = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
yield app
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mint_token(flask_app: Flask):
|
||||
"""Factory fixture; tracks minted rows and deletes them on teardown so
|
||||
the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
|
||||
minted: list[OAuthAccessToken] = []
|
||||
|
||||
def _mint(
|
||||
token: str,
|
||||
*,
|
||||
account_id: str | None,
|
||||
prefix: str,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
) -> OAuthAccessToken:
|
||||
with flask_app.app_context():
|
||||
row = OAuthAccessToken(
|
||||
token_hash=_sha256(token),
|
||||
prefix=prefix,
|
||||
account_id=account_id,
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
client_id="difyctl",
|
||||
device_label="test-device",
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
minted.append(row)
|
||||
return row
|
||||
|
||||
yield _mint
|
||||
|
||||
with flask_app.app_context():
|
||||
for row in minted:
|
||||
db.session.delete(db.session.merge(row))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account_token(workspace_account, mint_token) -> str:
|
||||
account, _, _ = workspace_account
|
||||
token = "dfoa_" + uuid.uuid4().hex
|
||||
mint_token(
|
||||
token,
|
||||
account_id=account.id,
|
||||
prefix="dfoa_",
|
||||
subject_email=account.email,
|
||||
subject_issuer="dify:account",
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
|
||||
def _flush():
|
||||
with flask_app.app_context():
|
||||
for k in redis_client.keys("auth:*"):
|
||||
redis_client.delete(k)
|
||||
for k in redis_client.keys("rl:*"):
|
||||
redis_client.delete(k)
|
||||
|
||||
_flush()
|
||||
yield
|
||||
_flush()
|
||||
238
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
238
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models import App
|
||||
|
||||
|
||||
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
captured["mode"] = app_model.mode
|
||||
captured["args"] = args
|
||||
captured["invoke_from"] = invoke_from
|
||||
return {
|
||||
"event": "message",
|
||||
"task_id": "t",
|
||||
"id": "m",
|
||||
"message_id": "m",
|
||||
"conversation_id": "c",
|
||||
"mode": "chat",
|
||||
"answer": "ok",
|
||||
"created_at": 0,
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.get_json()["mode"] == "chat"
|
||||
assert captured["mode"] == "chat"
|
||||
assert captured["invoke_from"] == InvokeFrom.OPENAPI
|
||||
assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_mode(flask_app: Flask, workspace_account):
|
||||
"""Factory that creates an App row in the workspace_account tenant with
|
||||
a specified mode. Tracks rows for teardown.
|
||||
"""
|
||||
_, tenant, _ = workspace_account
|
||||
created: list[App] = []
|
||||
|
||||
def _make(mode: str) -> App:
|
||||
with flask_app.app_context():
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"a-{mode}",
|
||||
mode=mode,
|
||||
status="normal",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db.session.refresh(app)
|
||||
db.session.expunge(app)
|
||||
created.append(app)
|
||||
return app
|
||||
|
||||
yield _make
|
||||
|
||||
with flask_app.app_context():
|
||||
for app in created:
|
||||
db.session.delete(db.session.merge(app))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"query_required_for_chat" in res.data
|
||||
|
||||
|
||||
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("completion")
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
captured["mode"] = app_model.mode
|
||||
captured["args"] = args
|
||||
return {
|
||||
"event": "message",
|
||||
"task_id": "t",
|
||||
"id": "m",
|
||||
"message_id": "m",
|
||||
"mode": "completion",
|
||||
"answer": "ok",
|
||||
"created_at": 0,
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.get_json()["mode"] == "completion"
|
||||
assert captured["mode"] == "completion"
|
||||
|
||||
|
||||
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("workflow")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"query_not_supported_for_workflow" in res.data
|
||||
|
||||
|
||||
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("workflow")
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
return {
|
||||
"workflow_run_id": "wfr",
|
||||
"task_id": "t",
|
||||
"data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.get_json()
|
||||
assert body["mode"] == "workflow"
|
||||
assert body["workflow_run_id"] == "wfr"
|
||||
|
||||
|
||||
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("channel")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"mode_not_runnable" in res.data
|
||||
|
||||
|
||||
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
)
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
"""Stub the authenticator to return an AuthContext with empty scopes."""
|
||||
from libs import oauth_bearer
|
||||
|
||||
real_authenticate = oauth_bearer.BearerAuthenticator.authenticate
|
||||
|
||||
def _stub_authenticate(self, token: str):
|
||||
ctx = real_authenticate(self, token)
|
||||
from dataclasses import replace
|
||||
|
||||
return replace(ctx, scopes=frozenset())
|
||||
|
||||
monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)
|
||||
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 403
|
||||
|
||||
|
||||
def test_run_with_unknown_app_returns_404(flask_app, account_token):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{uuid.uuid4()}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def _stream() -> Generator[str, None, None]:
|
||||
yield 'event: message\ndata: {"x": 1}\n\n'
|
||||
|
||||
monkeypatch.setattr(
|
||||
"controllers.openapi.app_run.AppGenerateService.generate",
|
||||
staticmethod(lambda **kw: _stream()),
|
||||
)
|
||||
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.headers["Content-Type"].startswith("text/event-stream")
|
||||
assert b"event: message" in res.data
|
||||
|
||||
|
||||
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Integration tests for /openapi/v1/apps* read surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from models import App
|
||||
|
||||
|
||||
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/parameters",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_info_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_describe_returns_merged_shape(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"]["id"] == app_in_workspace.id
|
||||
assert body["info"]["mode"] == "chat"
|
||||
assert isinstance(body["parameters"], dict)
|
||||
|
||||
|
||||
def test_apps_describe_full_includes_input_schema(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is not None
|
||||
assert body["input_schema"] is not None
|
||||
assert body["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
|
||||
def test_apps_describe_fields_info_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_parameters_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is None
|
||||
assert body["parameters"] is not None
|
||||
assert body["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_input_schema_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_combined(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_unknown_returns_422(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_apps_describe_fields_extra_param_returns_422(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_apps_list_returns_pagination_envelope(
|
||||
test_client: FlaskClient,
|
||||
workspace_account,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
_, tenant, _ = workspace_account
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["page"] == 1
|
||||
assert body["limit"] == 20
|
||||
assert body["total"] >= 1
|
||||
assert any(d["id"] == app_in_workspace.id for d in body["data"])
|
||||
|
||||
|
||||
def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
|
||||
res = test_client.get("/openapi/v1/apps", headers={"Authorization": f"Bearer {account_token}"})
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_apps_list_tag_no_match_returns_empty_data_not_400(
|
||||
test_client: FlaskClient,
|
||||
workspace_account,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
_, tenant, _ = workspace_account
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json["data"] == []
|
||||
|
||||
|
||||
def test_account_sessions_returns_envelope(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
# canonical envelope shape
|
||||
assert isinstance(body["data"], list)
|
||||
assert "page" in body
|
||||
assert "limit" in body
|
||||
assert "total" in body
|
||||
assert "has_more" in body
|
||||
# the bearer's own minted session must appear
|
||||
assert any(s["prefix"] == "dfoa_" for s in body["data"])
|
||||
# legacy "sessions" key must NOT appear
|
||||
assert "sessions" not in body
|
||||
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Integration tests for the /openapi/v1 bearer auth surface.
|
||||
|
||||
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
|
||||
acceptance/rejection on app-scoped routes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
def test_info_accepts_account_bearer_with_apps_read_scope(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json["id"] == app_in_workspace.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
|
||||
"""A fresh app under a *different* tenant — caller has no membership row."""
|
||||
with flask_app.app_context():
|
||||
other_tenant = Tenant(name="other", status="normal")
|
||||
db.session.add(other_tenant)
|
||||
db.session.commit()
|
||||
app = App(
|
||||
tenant_id=other_tenant.id,
|
||||
name="b",
|
||||
mode="chat",
|
||||
status="normal",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
yield app
|
||||
db.session.delete(app)
|
||||
db.session.delete(other_tenant)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_layer0_denies_account_bearer_without_membership(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
other_workspace_app: App,
|
||||
) -> None:
|
||||
"""Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 403
|
||||
assert res.json.get("message") == "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_layer0_skipped_when_enterprise_enabled(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
other_workspace_app: App,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.
|
||||
|
||||
/info uses validate_bearer + require_workspace_member inline (no
|
||||
AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
|
||||
gets 200 — gateway is expected to enforce isolation upstream.
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
# Override the conftest autouse default for this test only.
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)
|
||||
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json.get("message") != "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_rate_limit_returns_429_after_60_requests(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
"""61st sequential GET to /account on the same bearer → 429 with Retry-After."""
|
||||
headers = {"Authorization": f"Bearer {account_token}"}
|
||||
for i in range(60):
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 200, f"unexpected fail at i={i}"
|
||||
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 429
|
||||
assert r.headers.get("Retry-After"), "Retry-After header missing"
|
||||
assert int(r.headers["Retry-After"]) >= 1
|
||||
body = r.json or {}
|
||||
assert body.get("error") == "rate_limited"
|
||||
assert isinstance(body.get("retry_after_ms"), int)
|
||||
assert body["retry_after_ms"] >= 1000
|
||||
|
||||
|
||||
def test_rate_limit_bucket_shared_across_surfaces(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
"""30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
|
||||
headers = {"Authorization": f"Bearer {account_token}"}
|
||||
for _ in range(30):
|
||||
assert test_client.get("/openapi/v1/account", headers=headers).status_code == 200
|
||||
for _ in range(30):
|
||||
assert test_client.get(f"/openapi/v1/apps/{app_in_workspace.id}/info", headers=headers).status_code == 200
|
||||
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 429
|
||||
@ -291,7 +291,6 @@ class TestFeatureService:
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# Verify enterprise features
|
||||
assert result.enable_app_deploy is True
|
||||
assert result.branding.enabled is True
|
||||
assert result.webapp_auth.enabled is True
|
||||
assert result.enable_change_email is False
|
||||
@ -378,7 +377,6 @@ class TestFeatureService:
|
||||
# Ensure that data required for frontend rendering remains accessible.
|
||||
|
||||
# Branding should match the mock data
|
||||
assert result.enable_app_deploy is True
|
||||
assert result.branding.enabled is True
|
||||
assert result.branding.application_title == "Test Enterprise"
|
||||
assert result.branding.login_page_logo == "https://example.com/logo.png"
|
||||
@ -426,7 +424,6 @@ class TestFeatureService:
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# Verify basic configuration
|
||||
assert result.enable_app_deploy is False
|
||||
assert result.branding.enabled is False
|
||||
assert result.webapp_auth.enabled is False
|
||||
assert result.enable_change_email is True
|
||||
@ -628,7 +625,6 @@ class TestFeatureService:
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# Verify enterprise features are disabled
|
||||
assert result.enable_app_deploy is False
|
||||
assert result.branding.enabled is False
|
||||
assert result.webapp_auth.enabled is False
|
||||
assert result.enable_change_email is True
|
||||
|
||||
@ -1,105 +0,0 @@
|
||||
"""Unit tests for runtime credential inner API."""
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.inner_api.runtime_credentials import (
|
||||
EnterpriseRuntimeModelCredentialsResolve,
|
||||
InnerRuntimeModelCredentialsResolvePayload,
|
||||
)
|
||||
|
||||
|
||||
def test_runtime_model_credentials_payload_accepts_items():
|
||||
payload = InnerRuntimeModelCredentialsResolvePayload.model_validate(
|
||||
{
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [
|
||||
{
|
||||
"credential_id": "credential-1",
|
||||
"provider": "langgenius/openai/openai",
|
||||
"vendor": "openai",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
assert payload.tenant_id == "tenant-1"
|
||||
assert payload.credentials[0].provider == "langgenius/openai/openai"
|
||||
|
||||
|
||||
@patch("controllers.inner_api.runtime_credentials.encrypter.decrypt_token")
|
||||
@patch("controllers.inner_api.runtime_credentials.db")
|
||||
@patch("controllers.inner_api.runtime_credentials.Session")
|
||||
@patch("controllers.inner_api.runtime_credentials.create_plugin_provider_manager")
|
||||
def test_runtime_model_credentials_resolve_returns_decrypted_values(
|
||||
mock_provider_manager_factory,
|
||||
mock_session_cls,
|
||||
mock_db,
|
||||
mock_decrypt_token,
|
||||
app: Flask,
|
||||
):
|
||||
provider_configuration = MagicMock()
|
||||
provider_configuration.provider.provider_credential_schema.credential_form_schemas = []
|
||||
provider_configuration.extract_secret_variables.return_value = ["openai_api_key"]
|
||||
provider_configuration._get_provider_names.return_value = ["langgenius/openai/openai", "openai"]
|
||||
|
||||
provider_configurations = MagicMock()
|
||||
provider_configurations.get.return_value = provider_configuration
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value = provider_configurations
|
||||
mock_provider_manager_factory.return_value = provider_manager
|
||||
|
||||
credential = MagicMock()
|
||||
credential.encrypted_config = '{"openai_api_key":"encrypted","api_base":"https://api.openai.com/v1"}'
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
session.execute.return_value.scalar_one_or_none.return_value = credential
|
||||
mock_session_cls.return_value = session
|
||||
mock_db.engine = MagicMock()
|
||||
mock_decrypt_token.return_value = "sk-test"
|
||||
|
||||
handler = EnterpriseRuntimeModelCredentialsResolve()
|
||||
unwrapped = inspect.unwrap(handler.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.runtime_credentials.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [
|
||||
{
|
||||
"credential_id": "credential-1",
|
||||
"provider": "langgenius/openai/openai",
|
||||
"vendor": "openai",
|
||||
}
|
||||
],
|
||||
}
|
||||
body, status_code = unwrapped(handler)
|
||||
|
||||
assert status_code == 200
|
||||
assert body["model_credentials"][0]["values"]["openai_api_key"] == "sk-test"
|
||||
assert body["model_credentials"][0]["values"]["api_base"] == "https://api.openai.com/v1"
|
||||
mock_decrypt_token.assert_called_once_with(tenant_id="tenant-1", token="encrypted")
|
||||
|
||||
|
||||
@patch("controllers.inner_api.runtime_credentials.create_plugin_provider_manager")
|
||||
def test_runtime_model_credentials_resolve_rejects_unknown_provider(mock_provider_manager_factory, app: Flask):
|
||||
provider_configurations = MagicMock()
|
||||
provider_configurations.get.return_value = None
|
||||
provider_manager = MagicMock()
|
||||
provider_manager.get_configurations.return_value = provider_configurations
|
||||
mock_provider_manager_factory.return_value = provider_manager
|
||||
|
||||
handler = EnterpriseRuntimeModelCredentialsResolve()
|
||||
unwrapped = inspect.unwrap(handler.post)
|
||||
with app.test_request_context():
|
||||
with patch("controllers.inner_api.runtime_credentials.inner_api_ns") as mock_ns:
|
||||
mock_ns.payload = {
|
||||
"tenant_id": "tenant-1",
|
||||
"credentials": [{"credential_id": "credential-1", "provider": "missing"}],
|
||||
}
|
||||
body, status_code = unwrapped(handler)
|
||||
|
||||
assert status_code == 404
|
||||
assert "provider" in body["message"]
|
||||
@ -0,0 +1,66 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def test_pipeline_is_composed():
|
||||
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
|
||||
|
||||
|
||||
def test_pipeline_step_order():
|
||||
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
|
||||
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
|
||||
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
|
||||
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
|
||||
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
|
||||
steps = OAUTH_BEARER_PIPELINE._steps
|
||||
assert isinstance(steps[0], BearerCheck)
|
||||
assert isinstance(steps[1], SurfaceCheck)
|
||||
assert isinstance(steps[2], ScopeCheck)
|
||||
assert isinstance(steps[3], AppResolver)
|
||||
assert isinstance(steps[4], WorkspaceMembershipCheck)
|
||||
assert isinstance(steps[5], AppAuthzCheck)
|
||||
assert isinstance(steps[6], CallerMount)
|
||||
|
||||
|
||||
def test_pipeline_surface_check_accepts_account_only():
|
||||
"""Current pipeline serves /apps/<id>/run — account surface only."""
|
||||
surface = OAUTH_BEARER_PIPELINE._steps[1]
|
||||
assert isinstance(surface, SurfaceCheck)
|
||||
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
|
||||
|
||||
|
||||
def test_caller_mount_has_both_mounters():
|
||||
cm = OAUTH_BEARER_PIPELINE._steps[6]
|
||||
kinds = {type(m) for m in cm._mounters}
|
||||
assert AccountMounter in kinds
|
||||
assert EndUserMounter in kinds
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_acl_when_enabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = True
|
||||
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_membership_when_disabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = False
|
||||
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
|
||||
@ -0,0 +1,21 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
assert ctx.scopes == frozenset()
|
||||
assert ctx.app is None
|
||||
assert ctx.tenant is None
|
||||
assert ctx.caller is None
|
||||
assert ctx.caller_kind is None
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
@ -0,0 +1,61 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
def test_run_invokes_each_step_in_order():
|
||||
calls = []
|
||||
|
||||
class S:
|
||||
def __init__(self, tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_run_short_circuits_on_raise():
|
||||
calls = []
|
||||
|
||||
class Boom:
|
||||
def __call__(self, ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Tail:
|
||||
def __call__(self, ctx):
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
|
||||
seen = {}
|
||||
|
||||
class FakeStep:
|
||||
def __call__(self, ctx):
|
||||
ctx.app = "APP"
|
||||
ctx.caller = "CALLER"
|
||||
ctx.caller_kind = "account"
|
||||
|
||||
pipeline = Pipeline(FakeStep())
|
||||
|
||||
@pipeline.guard(scope="apps:run")
|
||||
def handler(app_model, caller, caller_kind):
|
||||
seen["app_model"] = app_model
|
||||
seen["caller"] = caller
|
||||
seen["caller_kind"] = caller_kind
|
||||
return "ok"
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/x", method="POST"):
|
||||
assert handler() == "ok"
|
||||
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
|
||||
@ -0,0 +1,64 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(view_args):
|
||||
req = MagicMock()
|
||||
req.view_args = view_args
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
|
||||
|
||||
|
||||
def _tenant(*, status=TenantStatus.NORMAL):
|
||||
return SimpleNamespace(id="t1", status=status)
|
||||
|
||||
|
||||
def test_resolver_rejects_missing_path_param():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_none_view_args():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_404_when_app_missing(db):
|
||||
db.session.get.side_effect = [None]
|
||||
with pytest.raises(NotFound):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_disabled(db):
|
||||
db.session.get.side_effect = [_app(enable_api=False)]
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
assert "service_api_disabled" in str(exc.value.description)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_tenant_archived(db):
|
||||
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
|
||||
with pytest.raises(Forbidden):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_populates_app_and_tenant(db):
|
||||
db.session.get.side_effect = [_app(), _tenant()]
|
||||
ctx = _ctx({"app_id": "x"})
|
||||
AppResolver()(ctx)
|
||||
assert ctx.app.id == "app1"
|
||||
assert ctx.tenant.id == "t1"
|
||||
@ -0,0 +1,75 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppAuthzCheck
|
||||
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_private_calls_inner_api(ent):
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
|
||||
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
|
||||
user_id="acc1",
|
||||
app_id="app1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "subject_type", "expected"),
|
||||
[
|
||||
("public", SubjectType.ACCOUNT, True),
|
||||
("public", SubjectType.EXTERNAL_SSO, True),
|
||||
("sso_verified", SubjectType.ACCOUNT, True),
|
||||
("sso_verified", SubjectType.EXTERNAL_SSO, True),
|
||||
("private_all", SubjectType.ACCOUNT, True),
|
||||
("private_all", SubjectType.EXTERNAL_SSO, False),
|
||||
("private", SubjectType.EXTERNAL_SSO, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
|
||||
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
|
||||
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
|
||||
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
|
||||
def test_membership_strategy_uses_join_lookup(member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with("acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
deny = SimpleNamespace(authorize=lambda c: False)
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
assert "subject_no_app_access" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_app_authz_check_passes_when_strategy_allows():
|
||||
allow = SimpleNamespace(authorize=lambda c: True)
|
||||
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,67 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
|
||||
|
||||
|
||||
def _ctx(headers):
|
||||
req = MagicMock()
|
||||
req.headers = headers
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
|
||||
tok_id = uuid.uuid4()
|
||||
authn = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="a@x.com",
|
||||
subject_issuer=None,
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=tok_id,
|
||||
source="oauth-account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="hash-1",
|
||||
verified_tenants={},
|
||||
)
|
||||
get_auth.return_value.authenticate.return_value = authn
|
||||
|
||||
app = Flask(__name__)
|
||||
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
|
||||
with app.test_request_context():
|
||||
BearerCheck()(ctx)
|
||||
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on `g.auth_ctx`
|
||||
# so the surface gate + downstream handlers don't see two
|
||||
# different identity sources between the decorator + pipeline paths.
|
||||
assert g.auth_ctx is authn
|
||||
assert g.auth_ctx.client_id == "difyctl"
|
||||
@ -0,0 +1,157 @@
|
||||
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(request=MagicMock(), required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
c.cached_verified_tenants = cached_verified_tenants
|
||||
c.token_hash = token_hash
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def step():
|
||||
return WorkspaceMembershipCheck()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id=str(uuid.uuid4()),
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
account_id=None,
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": True},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": False},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_record.assert_called_once_with("hash-1", "t1", True)
|
||||
@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import CallerMount
|
||||
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_account_mounter(db, login):
|
||||
account = SimpleNamespace()
|
||||
db.session.get.return_value = account
|
||||
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
|
||||
AccountMounter().mount(ctx)
|
||||
assert ctx.caller is account
|
||||
assert ctx.caller.current_tenant is ctx.tenant
|
||||
assert ctx.caller_kind == "account"
|
||||
login.assert_called_once_with(account)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.EndUserService")
|
||||
def test_end_user_mounter(svc, login):
|
||||
eu = SimpleNamespace()
|
||||
svc.get_or_create_end_user_by_type.return_value = eu
|
||||
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
|
||||
EndUserMounter().mount(ctx)
|
||||
svc.get_or_create_end_user_by_type.assert_called_once_with(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id="t1",
|
||||
app_id="app1",
|
||||
user_id="a@x.com",
|
||||
)
|
||||
assert ctx.caller is eu
|
||||
assert ctx.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_caller_mount_dispatches_by_subject_type():
|
||||
seen = {}
|
||||
|
||||
class Fake:
|
||||
def __init__(self, st, tag):
|
||||
self._st, self._tag = st, tag
|
||||
|
||||
def applies_to(self, st):
|
||||
return st == self._st
|
||||
|
||||
def mount(self, ctx):
|
||||
seen["who"] = self._tag
|
||||
|
||||
cm = CallerMount(
|
||||
Fake(SubjectType.ACCOUNT, "acct"),
|
||||
Fake(SubjectType.EXTERNAL_SSO, "sso"),
|
||||
)
|
||||
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
|
||||
assert seen == {"who": "sso"}
|
||||
|
||||
|
||||
def test_caller_mount_raises_when_none_applies():
|
||||
with pytest.raises(Unauthorized):
|
||||
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(request=MagicMock(), required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
def test_scope_check_passes_on_full():
|
||||
ScopeCheck()(_ctx({"full"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_passes_on_explicit_match():
|
||||
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_rejects_when_missing():
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
|
||||
assert "insufficient_scope" in str(exc.value.description)
|
||||
@ -0,0 +1,181 @@
|
||||
"""Surface gate tests.
|
||||
|
||||
The gate has two attachment forms — decorator (`accept_subjects`) and
|
||||
pipeline step (`SurfaceCheck`) — and both must:
|
||||
- 403 on mismatched subject type with a canonical-path hint
|
||||
- emit `openapi.wrong_surface_denied` once with the right payload
|
||||
- pass-through on match
|
||||
- raise RuntimeError (not 403) if g.auth_ctx is missing — that's a
|
||||
wiring bug, not a user-driven failure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
|
||||
def _account_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=uuid.uuid4(),
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _sso_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_surface — shared core
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_surface_passes_when_subject_in_accepted():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
|
||||
|
||||
|
||||
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
assert "wrong_surface" in exc.value.description
|
||||
# canonical-path hint should point at the caller's surface,
|
||||
# not the surface they were rejected from
|
||||
assert "/openapi/v1/apps" in exc.value.description
|
||||
emit.assert_called_once()
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
|
||||
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
|
||||
assert kwargs["client_id"] == "difyctl"
|
||||
assert kwargs["token_id"] is not None
|
||||
|
||||
|
||||
def test_check_surface_rejects_sso_on_account_surface():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _sso_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
|
||||
|
||||
|
||||
def test_check_surface_runtime_error_when_g_auth_ctx_missing():
|
||||
"""Missing g.auth_ctx means the bearer layer didn't run — wiring bug,
|
||||
not a user-driven failure. Surface as RuntimeError (loud) so a future
|
||||
refactor doesn't accidentally let a route skip authentication and
|
||||
return a 403 that looks identical to a legitimate wrong-surface deny.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
with pytest.raises(RuntimeError):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @accept_subjects — decorator form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/account-only")
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def _account_only():
|
||||
return "ok"
|
||||
|
||||
@app.route("/external-only")
|
||||
@accept_subjects(SubjectType.EXTERNAL_SSO)
|
||||
def _external_only():
|
||||
return "ok"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_passes_on_match():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/account-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
# Re-route through the decorated function by reaching for view_function
|
||||
view = app.view_functions["_account_only"]
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_403_on_miss():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/external-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
view = app.view_functions["_external_only"]
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfaceCheck — pipeline step form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _pipeline_ctx() -> Context:
|
||||
req = MagicMock()
|
||||
req.path = "/openapi/v1/apps/<id>/run"
|
||||
return Context(request=req, required_scope=Scope.APPS_RUN)
|
||||
|
||||
|
||||
def test_surface_check_passes_on_match():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
step(_pipeline_ctx()) # no raise
|
||||
|
||||
|
||||
def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
emit.assert_called_once()
|
||||
15
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
15
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,15 @@
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
|
||||
|
||||
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
|
||||
pipeline at import time; mocking the module attribute does not undo
|
||||
that. Patching Pipeline.run on the class is the bypass that actually
|
||||
works.
|
||||
"""
|
||||
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
|
||||
140
api/tests/unit_tests/controllers/openapi/test_account.py
Normal file
140
api/tests/unit_tests/controllers/openapi/test_account.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""User-scoped identity + session endpoints under /openapi/v1/account."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.account import (
|
||||
AccountApi,
|
||||
AccountSessionByIdApi,
|
||||
AccountSessionsApi,
|
||||
AccountSessionsSelfApi,
|
||||
)
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_account_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account" in rules
|
||||
|
||||
|
||||
def test_account_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountApi
|
||||
|
||||
|
||||
def test_account_sessions_self_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/self" in rules
|
||||
|
||||
|
||||
def test_sessions_self_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsSelfApi
|
||||
|
||||
|
||||
def test_account_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_self_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_list_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions" in rules
|
||||
|
||||
|
||||
def test_sessions_list_dispatches_to_sessions_api(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsApi
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_session_by_id_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/<string:session_id>" in rules
|
||||
|
||||
|
||||
def test_session_by_id_dispatches_to_correct_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/<string:session_id>")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionByIdApi
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_subject_match_for_account_filters_by_account_id():
|
||||
"""Account subject scopes queries via account_id."""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
aid = _uuid.uuid4()
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=aid,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
# One predicate, on account_id
|
||||
assert len(clauses) == 1
|
||||
assert "account_id" in str(clauses[0])
|
||||
|
||||
|
||||
def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
"""External SSO subject scopes via (subject_email, subject_issuer)
|
||||
AND account_id IS NULL — so a same-email account row from a
|
||||
federated tenant cannot be revoked through an SSO bearer.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
assert len(clauses) == 3
|
||||
rendered = " ".join(str(c) for c in clauses)
|
||||
assert "subject_email" in rendered
|
||||
assert "subject_issuer" in rendered
|
||||
assert "account_id IS NULL" in rendered
|
||||
@ -0,0 +1,48 @@
|
||||
"""Unit tests for AppDescribeQuery (`?fields=` allow-list)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.apps import AppDescribeQuery
|
||||
|
||||
|
||||
def test_no_fields_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_empty_string_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": ""})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_single_field() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info"})
|
||||
assert q.fields == {"info"}
|
||||
|
||||
|
||||
def test_comma_list() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info,parameters"})
|
||||
assert q.fields == {"info", "parameters"}
|
||||
|
||||
|
||||
def test_whitespace_tolerant() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": " info , input_schema "})
|
||||
assert q.fields == {"info", "input_schema"}
|
||||
|
||||
|
||||
def test_unknown_member_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "garbage"})
|
||||
|
||||
|
||||
def test_unknown_among_known_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info,garbage"})
|
||||
|
||||
|
||||
def test_extra_param_forbidden() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info", "page": "1"})
|
||||
105
api/tests/unit_tests/controllers/openapi/test_app_list_query.py
Normal file
105
api/tests/unit_tests/controllers/openapi/test_app_list_query.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Unit tests for AppListQuery — the /apps query-param validator.
|
||||
|
||||
Runs against the model directly, not the HTTP layer. Pins:
|
||||
- defaults match the plan (page=1, limit=20).
|
||||
- workspace_id is required.
|
||||
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
|
||||
- mode validates against the AppMode enum.
|
||||
- name and tag have length caps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT
|
||||
from controllers.openapi.apps import AppListQuery
|
||||
|
||||
|
||||
def test_defaults():
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1"})
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 1
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
assert q.tag is None
|
||||
|
||||
|
||||
def test_workspace_id_required():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({})
|
||||
|
||||
|
||||
def test_page_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": -1})
|
||||
|
||||
|
||||
def test_page_rejects_non_integer_string():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": "abc"})
|
||||
|
||||
|
||||
def test_limit_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": -1})
|
||||
|
||||
|
||||
def test_limit_caps_at_max_page_limit():
|
||||
# Boundary accepts.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT})
|
||||
assert q.limit == MAX_PAGE_LIMIT
|
||||
|
||||
# Just over rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT + 1})
|
||||
|
||||
|
||||
def test_mode_whitelisted_against_app_mode():
|
||||
# Valid mode passes.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "chat"})
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "chat"
|
||||
|
||||
# Invalid mode rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "not-a-mode"})
|
||||
|
||||
|
||||
def test_name_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 200})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 201})
|
||||
|
||||
|
||||
def test_tag_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 100})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 101})
|
||||
|
||||
|
||||
def test_all_fields_accept_valid_values():
|
||||
"""Pin the happy-path acceptance for every field in one place."""
|
||||
q = AppListQuery.model_validate(
|
||||
{
|
||||
"workspace_id": "ws-1",
|
||||
"page": 5,
|
||||
"limit": 50,
|
||||
"mode": "workflow",
|
||||
"name": "search",
|
||||
"tag": "prod",
|
||||
}
|
||||
)
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 5
|
||||
assert q.limit == 50
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "workflow"
|
||||
assert q.name == "search"
|
||||
assert q.tag == "prod"
|
||||
@ -0,0 +1,55 @@
|
||||
"""Unit tests for app payload-rendering helpers — independent of
|
||||
HTTP plumbing or DB. Pin the response shapes that are CLI contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.apps import ( # pyright: ignore[reportPrivateUsage]
|
||||
_EMPTY_PARAMETERS,
|
||||
parameters_payload,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
|
||||
|
||||
def _fake_app(**overrides):
|
||||
base = {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author_name": "alice",
|
||||
"tags": [SimpleNamespace(name="prod")],
|
||||
"updated_at": None,
|
||||
"enable_api": True,
|
||||
"workflow": None,
|
||||
"app_model_config": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def test_parameters_payload_raises_app_unavailable_when_no_config():
|
||||
with pytest.raises(AppUnavailableError):
|
||||
parameters_payload(_fake_app(mode="chat", app_model_config=None))
|
||||
|
||||
|
||||
def test_empty_parameters_constant_matches_describe_fallback_shape():
|
||||
"""The fallback dict served by /describe when an app has no config
|
||||
must match the spec's stated keys (opening_statement, suggested_questions,
|
||||
user_input_form, file_upload, system_parameters)."""
|
||||
assert set(_EMPTY_PARAMETERS.keys()) == {
|
||||
"opening_statement",
|
||||
"suggested_questions",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"system_parameters",
|
||||
}
|
||||
assert _EMPTY_PARAMETERS["suggested_questions"] == []
|
||||
assert _EMPTY_PARAMETERS["user_input_form"] == []
|
||||
assert _EMPTY_PARAMETERS["opening_statement"] is None
|
||||
assert _EMPTY_PARAMETERS["file_upload"] is None
|
||||
assert _EMPTY_PARAMETERS["system_parameters"] == {}
|
||||
@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.openapi.app_run import (
|
||||
_DISPATCH,
|
||||
AppRunRequest,
|
||||
_unpack_blocking,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_dispatch_covers_runnable_modes():
|
||||
runnable = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW}
|
||||
assert set(_DISPATCH) == runnable
|
||||
|
||||
|
||||
def test_unpack_blocking_passes_through_mapping():
|
||||
assert _unpack_blocking({"a": 1}) == {"a": 1}
|
||||
|
||||
|
||||
def test_unpack_blocking_unwraps_tuple():
|
||||
assert _unpack_blocking(({"a": 1}, 200)) == {"a": 1}
|
||||
|
||||
|
||||
def test_unpack_blocking_rejects_non_mapping():
|
||||
with pytest.raises(InternalServerError):
|
||||
_unpack_blocking("not a mapping")
|
||||
|
||||
|
||||
def test_app_run_request_strips_blank_conversation_id():
|
||||
payload = AppRunRequest(inputs={}, conversation_id=" ")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
|
||||
def test_app_run_request_rejects_invalid_uuid_conversation_id():
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="conversation_id must be a valid UUID"):
|
||||
AppRunRequest(inputs={}, conversation_id="not-a-uuid")
|
||||
|
||||
|
||||
def test_app_run_request_accepts_valid_uuid_conversation_id():
|
||||
import uuid as _uuid
|
||||
|
||||
cid = str(_uuid.uuid4())
|
||||
payload = AppRunRequest(inputs={}, conversation_id=cid)
|
||||
assert payload.conversation_id == cid
|
||||
@ -0,0 +1,53 @@
|
||||
"""Unit tests for PermittedExternalAppsListQuery — the
|
||||
/permitted-external-apps query validator.
|
||||
|
||||
Strict ConfigDict(extra='forbid'): cross-tenant tag/workspace_id are
|
||||
unresolvable, so the model must reject them as 422 instead of silently
|
||||
dropping them. Mode/name/page/limit have the same shape as AppListQuery.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.apps_permitted_external import PermittedExternalAppsListQuery
|
||||
|
||||
|
||||
def test_query_defaults_match_apps_list():
|
||||
q = PermittedExternalAppsListQuery.model_validate({})
|
||||
assert q.page == 1
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
|
||||
|
||||
def test_query_rejects_workspace_id():
|
||||
"""workspace_id is meaningless for /permitted-external-apps (cross-tenant);
|
||||
rejecting it forces CLI authors to drop the param rather than send it
|
||||
silently."""
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"workspace_id": "ws-1"})
|
||||
|
||||
|
||||
def test_query_rejects_tag():
|
||||
"""Tags are tenant-scoped; cross-tenant tag resolution is undefined."""
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"tag": "prod"})
|
||||
|
||||
|
||||
def test_query_validates_mode_against_app_mode():
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"mode": "not-a-mode"})
|
||||
|
||||
|
||||
def test_query_clamps_limit_at_max():
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"limit": 500})
|
||||
|
||||
|
||||
def test_query_accepts_valid_mode():
|
||||
"""Pin the happy path: AppMode values pass."""
|
||||
q = PermittedExternalAppsListQuery.model_validate({"mode": "chat"})
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "chat"
|
||||
@ -0,0 +1,26 @@
|
||||
import logging
|
||||
|
||||
from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run
|
||||
|
||||
|
||||
def test_event_constant():
|
||||
assert EVENT_APP_RUN_OPENAPI == "app.run.openapi"
|
||||
|
||||
|
||||
def test_emit_app_run_logs_with_audit_extra(caplog):
|
||||
with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"):
|
||||
emit_app_run(
|
||||
app_id="app1",
|
||||
tenant_id="t1",
|
||||
caller_kind="account",
|
||||
mode="chat",
|
||||
surface="apps",
|
||||
)
|
||||
record = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message)
|
||||
assert record.audit is True
|
||||
assert record.event == EVENT_APP_RUN_OPENAPI
|
||||
assert record.app_id == "app1"
|
||||
assert record.tenant_id == "t1"
|
||||
assert record.caller_kind == "account"
|
||||
assert record.mode == "chat"
|
||||
assert record.surface == "apps"
|
||||
127
api/tests/unit_tests/controllers/openapi/test_cors.py
Normal file
127
api/tests/unit_tests/controllers/openapi/test_cors.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""CORS posture for /openapi/v1/* — default empty allowlist (same-origin),
|
||||
expandable via OPENAPI_CORS_ALLOW_ORIGINS. Cross-origin requests from
|
||||
disallowed origins do not receive the Access-Control-Allow-Origin
|
||||
header, which the browser then blocks.
|
||||
|
||||
Tests use a fresh Blueprint + Flask-CORS per case because the production
|
||||
blueprint is a module-level singleton and can't be reconfigured once
|
||||
registered.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
from flask.views import MethodView
|
||||
from flask_cors import CORS
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_blueprints import OPENAPI_HEADERS, OPENAPI_MAX_AGE_SECONDS
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _make_app(allowed_origins: list[str], blueprint_name: str) -> Flask:
|
||||
"""Build a Flask app with a fresh openapi-style blueprint mirroring
|
||||
production CORS settings, parameterised on the origin allowlist.
|
||||
"""
|
||||
bp = Blueprint(blueprint_name, __name__, url_prefix="/openapi/v1")
|
||||
api = ExternalApi(bp, version="1.0", title="OpenAPI Test", description="")
|
||||
|
||||
@api.route("/_health")
|
||||
class _Health(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
|
||||
CORS(
|
||||
bp,
|
||||
resources={r"/*": {"origins": allowed_origins}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=["X-Version"],
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_default_openapi_cors_allowlist_is_empty():
|
||||
"""Default config admits no cross-origin until operator opts in."""
|
||||
assert dify_config.OPENAPI_CORS_ALLOW_ORIGINS == []
|
||||
|
||||
|
||||
def test_preflight_allowed_origin_returns_cors_headers():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t1")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.headers.get("Access-Control-Allow-Origin") == "https://app.example.com"
|
||||
assert response.headers.get("Access-Control-Max-Age") == str(OPENAPI_MAX_AGE_SECONDS)
|
||||
|
||||
|
||||
def test_preflight_disallowed_origin_omits_cors_headers():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t2")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://attacker.example",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
# flask-cors omits Allow-Origin for disallowed origins; browser blocks.
|
||||
assert "Access-Control-Allow-Origin" not in response.headers
|
||||
|
||||
|
||||
def test_preflight_with_default_empty_allowlist_omits_cors_headers():
|
||||
app = _make_app([], "openapi_t3")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
assert "Access-Control-Allow-Origin" not in response.headers
|
||||
|
||||
|
||||
def test_same_origin_request_succeeds_without_origin_header():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t4")
|
||||
client = app.test_client()
|
||||
# Browsers don't send Origin on same-origin GETs.
|
||||
response = client.get("/openapi/v1/_health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"ok": True}
|
||||
|
||||
|
||||
def test_authorization_header_is_in_allow_headers():
|
||||
"""Bearer-authed routes need Authorization in the preflight response."""
|
||||
app = _make_app(["https://app.example.com"], "openapi_t5")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
"Access-Control-Request-Headers": "Authorization",
|
||||
},
|
||||
)
|
||||
|
||||
allow_headers = response.headers.get("Access-Control-Allow-Headers", "").lower()
|
||||
assert "authorization" in allow_headers
|
||||
@ -0,0 +1,52 @@
|
||||
"""Account-branch device-flow approve/deny under /openapi/v1."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import DeviceApproveApi, DeviceDenyApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_approve_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approve" in rules
|
||||
|
||||
|
||||
def test_deny_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/deny" in rules
|
||||
|
||||
|
||||
def test_approve_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceApproveApi
|
||||
|
||||
|
||||
def test_deny_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceDenyApi
|
||||
|
||||
|
||||
def test_approve_and_deny_methods(openapi_app: Flask):
|
||||
approve = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
|
||||
deny = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
|
||||
assert "POST" in approve.methods
|
||||
assert "POST" in deny.methods
|
||||
47
api/tests/unit_tests/controllers/openapi/test_device_code.py
Normal file
47
api/tests/unit_tests/controllers/openapi/test_device_code.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""POST /openapi/v1/oauth/device/code is the canonical RFC 8628 device
|
||||
authorization endpoint.
|
||||
|
||||
Tests verify URL routing without invoking the handler — invoking would
|
||||
require Redis, which the unit-test runtime does not initialise.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceCodeApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/code" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi
|
||||
|
||||
|
||||
def test_route_accepts_post(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert "POST" in rule.methods
|
||||
|
||||
|
||||
def test_known_client_ids_default_includes_difyctl():
|
||||
from configs import dify_config
|
||||
|
||||
assert "difyctl" in dify_config.OPENAPI_KNOWN_CLIENT_IDS
|
||||
@ -0,0 +1,36 @@
|
||||
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceLookupApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/lookup" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi
|
||||
|
||||
|
||||
def test_route_accepts_get(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert "GET" in rule.methods
|
||||
105
api/tests/unit_tests/controllers/openapi/test_device_sso.py
Normal file
105
api/tests/unit_tests/controllers/openapi/test_device_sso.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
|
||||
|
||||
import builtins
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device_sso import (
|
||||
_email_belongs_to_dify_account,
|
||||
approval_context,
|
||||
approve_external,
|
||||
sso_complete,
|
||||
sso_initiate,
|
||||
)
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_sso_initiate_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/sso-initiate" in rules
|
||||
|
||||
|
||||
def test_sso_complete_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/sso-complete" in rules
|
||||
|
||||
|
||||
def test_approval_context_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approval-context" in rules
|
||||
|
||||
|
||||
def test_approve_external_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approve-external" in rules
|
||||
|
||||
|
||||
def test_sso_initiate_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-initiate")
|
||||
assert openapi_app.view_functions[rule.endpoint] is sso_initiate
|
||||
|
||||
|
||||
def test_sso_complete_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-complete")
|
||||
assert openapi_app.view_functions[rule.endpoint] is sso_complete
|
||||
|
||||
|
||||
def test_approval_context_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approval-context")
|
||||
assert openapi_app.view_functions[rule.endpoint] is approval_context
|
||||
|
||||
|
||||
def test_approve_external_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve-external")
|
||||
assert openapi_app.view_functions[rule.endpoint] is approve_external
|
||||
|
||||
|
||||
def test_sso_complete_idp_callback_url_uses_canonical_path():
|
||||
"""sso_initiate hardcodes the IdP callback URL — must point at the
|
||||
canonical /openapi/v1/ path so IdP-side ACS configuration matches.
|
||||
"""
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("email", "row", "expected"),
|
||||
[
|
||||
("alice@example.com", "acc1", True),
|
||||
("alice@example.com", None, False),
|
||||
("Alice@Example.COM", "acc1", True), # case-insensitive lookup
|
||||
(" alice@example.com ", "acc1", True), # surrounding whitespace stripped
|
||||
("", "acc1", False),
|
||||
(" ", "acc1", False),
|
||||
("", None, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.oauth_device_sso.db")
|
||||
def test_email_belongs_to_dify_account(db_mock, email, row, expected):
|
||||
exec_result = MagicMock()
|
||||
exec_result.scalar_one_or_none.return_value = row
|
||||
db_mock.session.execute.return_value = exec_result
|
||||
assert _email_belongs_to_dify_account(email) is expected
|
||||
if email.strip():
|
||||
db_mock.session.execute.assert_called_once()
|
||||
else:
|
||||
db_mock.session.execute.assert_not_called()
|
||||
@ -0,0 +1,31 @@
|
||||
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceTokenApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/token" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi
|
||||
33
api/tests/unit_tests/controllers/openapi/test_health.py
Normal file
33
api/tests/unit_tests/controllers/openapi/test_health.py
Normal file
@ -0,0 +1,33 @@
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_health_returns_ok(app: Flask):
|
||||
client = app.test_client()
|
||||
response = client.get("/openapi/v1/_health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"ok": True}
|
||||
|
||||
|
||||
def test_health_path_is_under_openapi_v1_prefix(app: Flask):
|
||||
client = app.test_client()
|
||||
assert client.get("/_health").status_code == 404
|
||||
assert client.get("/v1/_health").status_code == 404
|
||||
assert client.get("/openapi/v1/_health").status_code == 200
|
||||
182
api/tests/unit_tests/controllers/openapi/test_input_schema.py
Normal file
182
api/tests/unit_tests/controllers/openapi/test_input_schema.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""Unit tests for input_schema derivation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi._input_schema import _form_to_jsonschema
|
||||
|
||||
|
||||
def _wrap(component: dict) -> list[dict]:
|
||||
"""user_input_form rows are single-key dicts: {"text-input": {...}}."""
|
||||
return [component]
|
||||
|
||||
|
||||
def test_text_input_required() -> None:
|
||||
form = _wrap({"text-input": {"variable": "industry", "label": "Industry", "required": True, "max_length": 200}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"industry": {"type": "string", "title": "Industry", "maxLength": 200}}
|
||||
assert required == ["industry"]
|
||||
|
||||
|
||||
def test_paragraph_optional() -> None:
|
||||
form = _wrap({"paragraph": {"variable": "context", "label": "Context", "required": False, "max_length": 4000}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["context"] == {"type": "string", "title": "Context", "maxLength": 4000}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_select_enum() -> None:
|
||||
form = _wrap(
|
||||
{
|
||||
"select": {
|
||||
"variable": "tier",
|
||||
"label": "Tier",
|
||||
"required": True,
|
||||
"options": ["free", "pro", "enterprise"],
|
||||
}
|
||||
}
|
||||
)
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"tier": {"type": "string", "title": "Tier", "enum": ["free", "pro", "enterprise"]}}
|
||||
assert required == ["tier"]
|
||||
|
||||
|
||||
def test_number() -> None:
|
||||
form = _wrap({"number": {"variable": "count", "label": "Count", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["count"] == {"type": "number", "title": "Count"}
|
||||
|
||||
|
||||
def test_file() -> None:
|
||||
form = _wrap({"file": {"variable": "doc", "label": "Doc", "required": True}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["doc"]["type"] == "object"
|
||||
assert "title" in props["doc"]
|
||||
assert required == ["doc"]
|
||||
|
||||
|
||||
def test_file_list() -> None:
|
||||
form = _wrap({"file-list": {"variable": "attachments", "label": "Attachments", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["attachments"]["type"] == "array"
|
||||
assert props["attachments"]["items"]["type"] == "object"
|
||||
|
||||
|
||||
def test_unknown_type_skipped() -> None:
|
||||
"""Forward-compat: unknown variable types are skipped, not 500'd."""
|
||||
form = _wrap({"future-type": {"variable": "x", "label": "X", "required": False}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_required_order_preserved() -> None:
|
||||
form = [
|
||||
{"text-input": {"variable": "a", "label": "A", "required": True}},
|
||||
{"text-input": {"variable": "b", "label": "B", "required": False}},
|
||||
{"text-input": {"variable": "c", "label": "C", "required": True}},
|
||||
]
|
||||
_props, required = _form_to_jsonschema(form)
|
||||
assert required == ["a", "c"]
|
||||
|
||||
|
||||
def test_max_length_omitted_when_zero() -> None:
|
||||
form = _wrap({"text-input": {"variable": "x", "label": "X", "required": False, "max_length": 0}})
|
||||
props, _ = _form_to_jsonschema(form)
|
||||
assert "maxLength" not in props["x"]
|
||||
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _stub_app(mode: AppMode, *, form: list[dict] | None = None, has_workflow: bool | None = None):
|
||||
"""Returns a MagicMock whose .mode + workflow / app_model_config branch is wired up."""
|
||||
app = MagicMock()
|
||||
app.mode = mode
|
||||
if mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
|
||||
if has_workflow is False:
|
||||
app.workflow = None
|
||||
else:
|
||||
app.workflow = MagicMock()
|
||||
app.workflow.user_input_form.return_value = form or []
|
||||
app.workflow.features_dict = {}
|
||||
else:
|
||||
if has_workflow is False:
|
||||
app.app_model_config = None
|
||||
else:
|
||||
app.app_model_config = MagicMock()
|
||||
app.app_model_config.to_dict.return_value = {"user_input_form": form or []}
|
||||
return app
|
||||
|
||||
|
||||
def test_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.CHAT, form=[{"text-input": {"variable": "x", "label": "X", "required": True}}])
|
||||
schema = build_input_schema(app)
|
||||
assert schema["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
assert "query" in schema["properties"]
|
||||
assert schema["properties"]["query"]["type"] == "string"
|
||||
assert schema["properties"]["query"]["minLength"] == 1
|
||||
assert "query" in schema["required"]
|
||||
assert "inputs" in schema["required"]
|
||||
assert schema["properties"]["inputs"]["additionalProperties"] is False
|
||||
|
||||
|
||||
def test_agent_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.AGENT_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_advanced_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.ADVANCED_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_workflow_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_completion_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.COMPLETION, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_inputs_required_driven_by_form() -> None:
|
||||
app = _stub_app(
|
||||
AppMode.CHAT,
|
||||
form=[
|
||||
{"text-input": {"variable": "industry", "label": "Industry", "required": True}},
|
||||
{"text-input": {"variable": "context", "label": "Context", "required": False}},
|
||||
],
|
||||
)
|
||||
schema = build_input_schema(app)
|
||||
assert schema["properties"]["inputs"]["required"] == ["industry"]
|
||||
|
||||
|
||||
def test_misconfigured_chat_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.CHAT, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_misconfigured_workflow_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_empty_input_schema_sentinel_shape() -> None:
|
||||
assert EMPTY_INPUT_SCHEMA["type"] == "object"
|
||||
assert EMPTY_INPUT_SCHEMA["properties"] == {}
|
||||
assert EMPTY_INPUT_SCHEMA["required"] == []
|
||||
31
api/tests/unit_tests/controllers/openapi/test_models.py
Normal file
31
api/tests/unit_tests/controllers/openapi/test_models.py
Normal file
@ -0,0 +1,31 @@
|
||||
from controllers.openapi._models import MessageMetadata, UsageInfo
|
||||
|
||||
|
||||
def test_usage_info_defaults_zero():
|
||||
u = UsageInfo()
|
||||
assert u.prompt_tokens == 0
|
||||
assert u.completion_tokens == 0
|
||||
assert u.total_tokens == 0
|
||||
|
||||
|
||||
def test_message_metadata_accepts_partial():
|
||||
m = MessageMetadata(usage=UsageInfo(total_tokens=10))
|
||||
assert m.usage.total_tokens == 10
|
||||
assert m.retriever_resources == []
|
||||
|
||||
|
||||
def test_describe_response_all_blocks_optional() -> None:
|
||||
from controllers.openapi._models import AppDescribeResponse
|
||||
|
||||
payload = AppDescribeResponse().model_dump(mode="json", exclude_none=False)
|
||||
assert payload == {"info": None, "parameters": None, "input_schema": None}
|
||||
|
||||
|
||||
def test_describe_response_input_schema_field() -> None:
|
||||
from controllers.openapi._models import AppDescribeResponse
|
||||
|
||||
schema = {"$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object"}
|
||||
payload = AppDescribeResponse(input_schema=schema).model_dump(mode="json", exclude_none=False)
|
||||
assert payload["input_schema"] == schema
|
||||
assert payload["info"] is None
|
||||
assert payload["parameters"] is None
|
||||
@ -0,0 +1,140 @@
|
||||
"""Unit tests for PaginationEnvelope generic Pydantic model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.openapi._models import PaginationEnvelope
|
||||
|
||||
|
||||
class _Row(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
def test_envelope_basic_fields():
|
||||
env = PaginationEnvelope[_Row](page=1, limit=20, total=42, has_more=True, data=[_Row(id="a", name="A")])
|
||||
dumped = env.model_dump(mode="json")
|
||||
assert dumped == {
|
||||
"page": 1,
|
||||
"limit": 20,
|
||||
"total": 42,
|
||||
"has_more": True,
|
||||
"data": [{"id": "a", "name": "A"}],
|
||||
}
|
||||
|
||||
|
||||
def test_envelope_empty_data_no_more():
|
||||
env = PaginationEnvelope[_Row](page=1, limit=20, total=0, has_more=False, data=[])
|
||||
assert env.model_dump(mode="json")["data"] == []
|
||||
assert env.model_dump(mode="json")["has_more"] is False
|
||||
|
||||
|
||||
def test_envelope_has_more_true_when_total_exceeds_page_window():
|
||||
env = PaginationEnvelope[_Row].build(page=1, limit=20, total=42, items=[_Row(id="a", name="A")])
|
||||
assert env.has_more is True
|
||||
|
||||
|
||||
def test_envelope_has_more_false_when_total_within_page_window():
|
||||
env = PaginationEnvelope[_Row].build(page=2, limit=20, total=22, items=[_Row(id="a", name="A")])
|
||||
assert env.has_more is False
|
||||
|
||||
|
||||
def test_envelope_has_more_false_for_last_page():
|
||||
env = PaginationEnvelope[_Row].build(page=3, limit=20, total=42, items=[_Row(id="a", name="A")])
|
||||
assert env.has_more is False
|
||||
|
||||
|
||||
def test_max_page_limit_is_200():
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT
|
||||
|
||||
assert MAX_PAGE_LIMIT == 200
|
||||
|
||||
|
||||
def test_envelope_uses_pep695_generics():
|
||||
"""Verify the class uses PEP 695 native generic syntax (not legacy Generic[T])."""
|
||||
from controllers.openapi._models import PaginationEnvelope
|
||||
|
||||
# PEP 695 syntax populates __type_params__; the legacy Generic[T] form does not.
|
||||
assert PaginationEnvelope.__type_params__, "expected PEP 695 native generic syntax"
|
||||
|
||||
fields = PaginationEnvelope.model_fields
|
||||
assert {"page", "limit", "total", "has_more", "data"} <= set(fields)
|
||||
|
||||
|
||||
def test_app_info_response_dump_matches_spec():
|
||||
from controllers.openapi._models import AppInfoResponse
|
||||
|
||||
obj = AppInfoResponse(
|
||||
id="app1",
|
||||
name="X",
|
||||
description="d",
|
||||
mode="chat",
|
||||
author="alice",
|
||||
tags=[{"name": "prod"}],
|
||||
)
|
||||
assert obj.model_dump(mode="json") == {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author": "alice",
|
||||
"tags": [{"name": "prod"}],
|
||||
}
|
||||
|
||||
|
||||
def test_app_describe_response_nests_info_and_parameters():
|
||||
from controllers.openapi._models import AppDescribeInfo, AppDescribeResponse
|
||||
|
||||
info = AppDescribeInfo(
|
||||
id="app1",
|
||||
name="X",
|
||||
mode="chat",
|
||||
description=None,
|
||||
tags=[],
|
||||
author=None,
|
||||
updated_at="2026-05-05T00:00:00+00:00",
|
||||
service_api_enabled=True,
|
||||
)
|
||||
obj = AppDescribeResponse(info=info, parameters={"opening_statement": None})
|
||||
dumped = obj.model_dump(mode="json")
|
||||
assert dumped["info"]["service_api_enabled"] is True
|
||||
assert dumped["parameters"]["opening_statement"] is None
|
||||
|
||||
|
||||
def test_response_models_dump_per_mode():
|
||||
from controllers.openapi._models import (
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
WorkflowRunData,
|
||||
WorkflowRunResponse,
|
||||
)
|
||||
|
||||
chat = ChatMessageResponse(
|
||||
event="message",
|
||||
task_id="t1",
|
||||
id="m1",
|
||||
message_id="m1",
|
||||
conversation_id="c1",
|
||||
mode="chat",
|
||||
answer="hi",
|
||||
created_at=0,
|
||||
)
|
||||
assert chat.model_dump(mode="json")["mode"] == "chat"
|
||||
wf = WorkflowRunResponse(
|
||||
workflow_run_id="r1",
|
||||
task_id="t1",
|
||||
data=WorkflowRunData(id="r1", workflow_id="w1", status="succeeded"),
|
||||
)
|
||||
assert wf.model_dump(mode="json")["data"]["status"] == "succeeded"
|
||||
assert wf.model_dump(mode="json")["mode"] == "workflow"
|
||||
comp = CompletionMessageResponse(
|
||||
event="message",
|
||||
task_id="t2",
|
||||
id="m2",
|
||||
message_id="m2",
|
||||
mode="completion",
|
||||
answer="ok",
|
||||
created_at=0,
|
||||
)
|
||||
assert comp.model_dump(mode="json")["mode"] == "completion"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user