mirror of
https://github.com/langgenius/dify.git
synced 2026-05-18 16:06:36 +08:00
Compare commits
6 Commits
feat/cli
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| 688424429d | |||
| 9494e4f267 | |||
| 969760364d | |||
| ceabfeb3a7 | |||
| c407f40e0d | |||
| 28818f2e2a |
@ -1,15 +0,0 @@
|
||||
**/node_modules
|
||||
**/.pnpm-store
|
||||
**/dist
|
||||
**/.next
|
||||
**/.turbo
|
||||
**/.cache
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/.mypy_cache
|
||||
**/.ruff_cache
|
||||
.git
|
||||
.github
|
||||
*.md
|
||||
!web/README.md
|
||||
!api/README.md
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,10 +18,6 @@
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# CLI
|
||||
/cli/ @langgenius/maintainers
|
||||
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
|
||||
5
.github/actions/setup-web/action.yml
vendored
5
.github/actions/setup-web/action.yml
vendored
@ -1,8 +1,13 @@
|
||||
name: Setup Web Environment
|
||||
description: Set up Node.js, Vite+, pnpm, and web dependencies
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5
|
||||
with:
|
||||
run_install: false
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
|
||||
63
.github/workflows/cli-docker-build.yml
vendored
63
.github/workflows/cli-docker-build.yml
vendored
@ -1,63 +0,0 @@
|
||||
name: CLI Docker Build (dev)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- "cli/**"
|
||||
- "packages/tsconfig/**"
|
||||
- "pnpm-lock.yaml"
|
||||
- "pnpm-workspace.yaml"
|
||||
merge_group:
|
||||
branches:
|
||||
- "main"
|
||||
types: [checks_requested]
|
||||
|
||||
concurrency:
|
||||
group: cli-docker-build-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build CLI dev image
|
||||
if: github.event_name == 'merge_group' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build CLI Dockerfile.dev
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
context: "{{defaultContext}}"
|
||||
file: "cli/Dockerfile.dev"
|
||||
platforms: linux/amd64,linux/arm64
|
||||
|
||||
build-fork:
|
||||
name: Build CLI dev image (fork)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name != github.repository
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build CLI Dockerfile.dev
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: "."
|
||||
file: "cli/Dockerfile.dev"
|
||||
platforms: linux/amd64
|
||||
131
.github/workflows/cli-release.yml
vendored
131
.github/workflows/cli-release.yml
vendored
@ -1,131 +0,0 @@
|
||||
name: CLI Release
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_release_tag:
|
||||
description: "dify release tag to attach cli artifacts to (e.g. 1.14.0). Bare semver — dify tags are NOT v-prefixed."
|
||||
type: string
|
||||
required: true
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.event.release.tag_name || inputs.dify_release_tag }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: >-
|
||||
github.repository == 'langgenius/dify' &&
|
||||
(github.event_name == 'workflow_dispatch' ||
|
||||
(vars.CLI_AUTO_RELEASE == 'true' && !github.event.release.prerelease))
|
||||
env:
|
||||
DIFY_TAG: ${{ github.event.release.tag_name || inputs.dify_release_tag }}
|
||||
permissions:
|
||||
contents: write
|
||||
id-token: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Node registry auth
|
||||
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Bump guard (auto-path only)
|
||||
if: github.event_name == 'release'
|
||||
run: scripts/release-bump-guard.sh
|
||||
env:
|
||||
NEW_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
NEW_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
NEW_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
|
||||
- name: Verify target dify release exists
|
||||
run: gh release view "$DIFY_TAG" --repo langgenius/dify > /dev/null
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build cli
|
||||
run: |
|
||||
DIFYCTL_VERSION="${{ steps.manifest.outputs.version }}" \
|
||||
DIFYCTL_CHANNEL="${{ steps.manifest.outputs.channel }}" \
|
||||
DIFYCTL_MIN_DIFY="${{ steps.manifest.outputs.minDify }}" \
|
||||
DIFYCTL_MAX_DIFY="${{ steps.manifest.outputs.maxDify }}" \
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build
|
||||
|
||||
- name: Pack tarballs
|
||||
run: pnpm pack:tarballs
|
||||
|
||||
- name: Publish to npm (idempotent)
|
||||
run: scripts/release-npm-publish.sh
|
||||
env:
|
||||
CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
NEW_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
run: scripts/release-write-checksums.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
|
||||
- name: Install cosign
|
||||
uses: sigstore/cosign-installer@3454372f43399081ed03b604cb2d021dabca52bb # v3.8.2
|
||||
|
||||
- name: Keyless-sign tarballs + checksum file (Sigstore)
|
||||
run: scripts/release-cosign-sign.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
COSIGN_EXPERIMENTAL: '1'
|
||||
|
||||
- name: Snapshot tarballs + checksum + signatures as workflow artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
|
||||
with:
|
||||
name: difyctl-${{ steps.manifest.outputs.version }}-${{ env.DIFY_TAG }}
|
||||
path: |
|
||||
cli/dist/difyctl-v*.tar.xz
|
||||
cli/dist/difyctl-v*-checksums.txt
|
||||
cli/dist/difyctl-v*.sig
|
||||
cli/dist/difyctl-v*.pem
|
||||
retention-days: 90
|
||||
if-no-files-found: error
|
||||
|
||||
- name: Upload tarballs + checksum + signatures to dify GH release (idempotent)
|
||||
run: scripts/release-upload-tarballs.sh
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
57
.github/workflows/cli-smoke.yml
vendored
57
.github/workflows/cli-smoke.yml
vendored
@ -1,57 +0,0 @@
|
||||
name: CLI Smoke (live dify)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_version:
|
||||
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||
type: string
|
||||
required: true
|
||||
cli_ref:
|
||||
description: "Git ref to build the cli from (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
jobs:
|
||||
smoke:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Bring up dify
|
||||
env:
|
||||
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||
run: |
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||
docker compose up -d api worker web db redis
|
||||
for i in $(seq 1 60); do
|
||||
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||
echo "dify api ready after ${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run smoke against live dify
|
||||
working-directory: ./cli
|
||||
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||
|
||||
- name: Dump dify logs on failure
|
||||
if: failure()
|
||||
run: |
|
||||
cd docker
|
||||
docker compose logs api worker web --tail=200
|
||||
46
.github/workflows/cli-tests.yml
vendored
46
.github/workflows/cli-tests.yml
vendored
@ -1,46 +0,0 @@
|
||||
name: CLI Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,7 +42,6 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
@ -64,18 +63,6 @@ jobs:
|
||||
- 'docker/generate_docker_compose'
|
||||
- 'docker/ssrf_proxy/**'
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
cli:
|
||||
- 'cli/**'
|
||||
- 'packages/tsconfig/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- 'eslint.config.mjs'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/cli-tests.yml'
|
||||
- '.github/workflows/cli-docker-build.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
@ -197,66 +184,6 @@ jobs:
|
||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
cli-tests-run:
|
||||
name: Run CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||
uses: ./.github/workflows/cli-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests-skip:
|
||||
name: Skip CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped CLI tests
|
||||
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
if: ${{ always() }}
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
- cli-tests-run
|
||||
- cli-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize CLI Tests status
|
||||
env:
|
||||
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||
run: |
|
||||
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests ran successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests were skipped because no CLI-related files changed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
web-tests-run:
|
||||
name: Run Web Tests
|
||||
needs:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -115,12 +115,6 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||
!/cli/src/env/
|
||||
!/cli/src/commands/env/
|
||||
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||
!/cli/scripts/lib/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
@ -253,7 +247,6 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
*.local.toml
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
|
||||
@ -159,7 +159,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -204,7 +203,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -520,44 +520,6 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
OPENAPI_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||
"programmatic clients. Set to true to activate; disabled by default."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -933,17 +895,6 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1230,14 +1181,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -1,120 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspacePayload,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
WorkflowRunData,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
AccountResponse,
|
||||
SessionRow,
|
||||
SessionListResponse,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
human_input_form,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_events,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
# Request models are imported from _models.py and registered above.
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_events",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
@ -1,66 +0,0 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||
|
||||
|
||||
def emit_app_run(
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
caller_kind: str,
|
||||
mode: str,
|
||||
surface: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
surface,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
"surface": surface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_wrong_surface(
|
||||
*,
|
||||
subject_type: str | None,
|
||||
attempted_path: str,
|
||||
client_id: str | None,
|
||||
token_id: str | None,
|
||||
) -> None:
|
||||
logger.warning(
|
||||
"audit: %s subject_type=%s attempted_path=%s",
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
subject_type,
|
||||
attempted_path,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
"subject_type": subject_type,
|
||||
"attempted_path": attempted_path,
|
||||
"client_id": client_id,
|
||||
"token_id": token_id,
|
||||
},
|
||||
)
|
||||
@ -1,143 +0,0 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
@ -1,319 +0,0 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class PermittedExternalAppsListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
class AccountPayload(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspacePayload(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
subject_type: str
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account: AccountPayload | None = None
|
||||
workspaces: list[WorkspacePayload] = []
|
||||
default_workspace_id: str | None = None
|
||||
|
||||
|
||||
class SessionRow(BaseModel):
|
||||
id: str
|
||||
prefix: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
created_at: str | None = None
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WorkspaceSummaryResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
workspaces: list[WorkspaceSummaryResponse]
|
||||
|
||||
|
||||
class WorkspaceDetailResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class DeviceCodeResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceLookupResponse(BaseModel):
|
||||
valid: bool
|
||||
expires_in_remaining: int = 0
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class DeviceMutateResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
workspace_id: UUIDStrOrEmpty | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class PermittedExternalAppsListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
@ -1,249 +0,0 @@
|
||||
"""User-scoped account endpoints. /account is the bearer-authed
|
||||
identity read; /account/sessions and /account/sessions/<id> manage
|
||||
the user's active OAuth tokens.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import and_, select, update
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = (
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
|
||||
)
|
||||
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
_revoke_token_by_id(str(ctx.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = db.session.execute(
|
||||
select(
|
||||
OAuthAccessToken.id,
|
||||
OAuthAccessToken.prefix,
|
||||
OAuthAccessToken.client_id,
|
||||
OAuthAccessToken.device_label,
|
||||
OAuthAccessToken.created_at,
|
||||
OAuthAccessToken.last_used_at,
|
||||
OAuthAccessToken.expires_at,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
*_subject_match(ctx),
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
OAuthAccessToken.token_hash.is_not(None),
|
||||
OAuthAccessToken.expires_at > now,
|
||||
)
|
||||
)
|
||||
.order_by(OAuthAccessToken.created_at.desc())
|
||||
).all()
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
SessionRow(
|
||||
id=str(r.id),
|
||||
prefix=r.prefix,
|
||||
client_id=r.client_id,
|
||||
device_label=r.device_label,
|
||||
created_at=_iso(r.created_at),
|
||||
last_used_at=_iso(r.last_used_at),
|
||||
expires_at=_iso(r.expires_at),
|
||||
)
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# Subject-match guard. 404 (not 403) on cross-subject so the
|
||||
# endpoint doesn't leak token IDs that belong to other subjects.
|
||||
owns = db.session.execute(
|
||||
select(OAuthAccessToken.id).where(
|
||||
and_(
|
||||
OAuthAccessToken.id == session_id,
|
||||
*_subject_match(ctx),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if owns is None:
|
||||
raise NotFound("session not found")
|
||||
|
||||
_revoke_token_by_id(session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _subject_match(ctx: AuthContext) -> tuple:
|
||||
"""Where-clauses that scope a query to the bearer's subject. Works
|
||||
for both account (account_id) and external_sso (email + issuer).
|
||||
"""
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return (OAuthAccessToken.account_id == str(ctx.account_id),)
|
||||
return (
|
||||
OAuthAccessToken.subject_email == ctx.subject_email,
|
||||
OAuthAccessToken.subject_issuer == ctx.subject_issuer,
|
||||
OAuthAccessToken.account_id.is_(None),
|
||||
)
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _revoke_token_by_id(token_id: str) -> None:
|
||||
# Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
|
||||
# makes double-revoke idempotent.
|
||||
row = (
|
||||
db.session.query(OAuthAccessToken.token_hash)
|
||||
.filter(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
pre_revoke_hash = row[0] if row else None
|
||||
|
||||
stmt = (
|
||||
update(OAuthAccessToken)
|
||||
.where(
|
||||
OAuthAccessToken.id == token_id,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||
)
|
||||
db.session.execute(stmt)
|
||||
db.session.commit()
|
||||
|
||||
if pre_revoke_hash:
|
||||
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _load_memberships(account_id):
|
||||
return (
|
||||
db.session.query(TenantAccountJoin, Tenant)
|
||||
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.filter(TenantAccountJoin.account_id == account_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> WorkspacePayload:
|
||||
join, tenant = row
|
||||
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||
|
||||
|
||||
def _account_payload(account) -> AccountPayload:
|
||||
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||
@ -1,165 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
@ -1,280 +0,0 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → sets `g.auth_ctx` before `require_scope` reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = g.auth_ctx
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
|
||||
if not app or app.status != "normal" or not is_openapi_visible(app):
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = list(
|
||||
db.session.execute(
|
||||
apply_openapi_gate(
|
||||
sa.select(App).where(
|
||||
App.name == app_id,
|
||||
App.tenant_id == workspace_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = g.auth_ctx
|
||||
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
if parsed_uuid is not None:
|
||||
app: App = db.session.get(App, str(parsed_uuid))
|
||||
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
|
||||
return empty
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all",
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name: str | None = None
|
||||
if pagination.items:
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
env = AppListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=int(pagination.total),
|
||||
has_more=query.page * query.limit < int(pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,107 +0,0 @@
|
||||
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||
|
||||
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||
(public / sso_verified). License-gated: CE deploys never enable the
|
||||
EE blueprint chain so this module is unreachable there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
from services.openapi.visibility import apply_openapi_gate
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a
|
||||
for a in db.session.execute(apply_openapi_gate(sa.select(App).where(App.id.in_(page_result.app_ids))))
|
||||
.scalars()
|
||||
.all()
|
||||
}
|
||||
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
|
||||
tenants_by_id = {
|
||||
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
|
||||
}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=page_result.total,
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,3 +0,0 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
@ -1,46 +0,0 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
@ -1,46 +0,0 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
request: Request
|
||||
required_scope: Scope
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
@ -1,41 +0,0 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
ctx = Context(request=request, required_scope=scope)
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
@ -1,172 +0,0 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also assigns
|
||||
``g.auth_ctx`` (the same way ``validate_bearer`` does) so the surface gate
|
||||
+ any handler reading the request-scoped context has a single source of
|
||||
truth across both auth-attach paths.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from flask import g
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
_extract_bearer, # type: ignore[attr-defined]
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also attaches the resolved `AuthContext` to ``g.auth_ctx`` — same shape
|
||||
the decorator-level ``validate_bearer`` writes — so the surface gate
|
||||
+ downstream readers don't see two different identity sources."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
token = _extract_bearer(ctx.request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
|
||||
# Single source of truth for the request-scoped identity. Surface
|
||||
# gate + handlers read `g.auth_ctx` regardless of whether the route
|
||||
# ran the decorator path (`validate_bearer`) or the pipeline path.
|
||||
g.auth_ctx = authn
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if `g.auth_ctx.subject_type` is not in `accepted`.
|
||||
|
||||
Delegates to `surface_gate.check_surface` so the inline decorator and
|
||||
the pipeline step emit identical audit events. Relies on `BearerCheck`
|
||||
(above) having set `g.auth_ctx`.
|
||||
"""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling).
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = (ctx.request.view_args or {}).get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
@ -1,184 +0,0 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = db.session.execute(
|
||||
select(Account).where(Account.email == ctx.subject_email),
|
||||
).scalar_one_or_none()
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user)
|
||||
user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
@ -1,89 +0,0 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that ``g.auth_ctx.subject_type`` is in ``accepted``.
|
||||
|
||||
Raises ``Forbidden`` with ``wrong_surface`` + canonical-path hint on
|
||||
miss; emits ``openapi.wrong_surface_denied`` audit. If ``g.auth_ctx``
|
||||
is missing the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without g.auth_ctx; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
try:
|
||||
return SubjectType(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
@ -1,107 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed human input form endpoints.
|
||||
|
||||
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
|
||||
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App
|
||||
from services.human_input_service import FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
if caller_kind == "account":
|
||||
submission_user_id = caller.id
|
||||
else:
|
||||
submission_end_user_id = caller.id
|
||||
|
||||
if form.recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_token=%s", form_token)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
@ -1,9 +0,0 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
@ -1,404 +0,0 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Validation helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
"""Pre-render the poll-response body so the unauthenticated poll
|
||||
handler doesn't re-query accounts/tenants for authz data.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
rows = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.account_id == account.id)
|
||||
.all()
|
||||
)
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
return {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
@ -1,369 +0,0 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from controllers.openapi import bp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = (claims.get("user_code") or "").strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims["email"]):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = request.host_url.rstrip("/")
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims["email"],
|
||||
subject_issuer=claims["issuer"],
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or csrf_header != claims.csrf_token:
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if _email_belongs_to_dify_account(claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RejectedClaims:
|
||||
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||
|
||||
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||
event without reaching for the full dataclass.
|
||||
"""
|
||||
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _email_belongs_to_dify_account(email: str) -> bool:
|
||||
"""External SSO subjects whose email matches an active Dify Account must
|
||||
authenticate via the internal Dify login path (which mints dfoa_), not via
|
||||
the external SSO device flow. Returning True here blocks dfoe_ minting.
|
||||
|
||||
Pending/uninitialized/banned/closed accounts do not block: pending and
|
||||
uninitialized users may complete invitation via SSO; banned and closed
|
||||
accounts are handled by separate enforcement paths.
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
normalized = email.strip().lower()
|
||||
if not normalized:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(Account.id).where(
|
||||
func.lower(Account.email) == normalized,
|
||||
Account.status == AccountStatus.ACTIVE,
|
||||
),
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
reason,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_rejected",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"reason": reason,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
@ -1,119 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed workflow reconnect event stream endpoint.
|
||||
|
||||
GET /apps/<app_id>/tasks/<task_id>/events
|
||||
— reconnect to the SSE stream for a paused/running workflow run.
|
||||
`task_id` is treated as `workflow_run_id`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if caller_kind == "account":
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
else:
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=caller,
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
else:
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.OPENAPI,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -1,90 +0,0 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
rows = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == str(ctx.account_id))
|
||||
.order_by(Tenant.created_at.asc())
|
||||
).all()
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
row = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
Tenant.id == workspace_id,
|
||||
TenantAccountJoin.account_id == str(ctx.account_id),
|
||||
)
|
||||
).first()
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
)
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||
return WorkspaceDetailResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
)
|
||||
@ -16,7 +16,7 @@ from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, EndUser, Site
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(
|
||||
@ -88,8 +88,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
raise Unauthorized("Please re-login to access the web app.")
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
|
||||
!= WebAppAccessMode.PUBLIC
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
@ -731,8 +731,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.OPENAPI:
|
||||
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
|
||||
@ -24,7 +24,6 @@ class UserFrom(StrEnum):
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
OPENAPI = "openapi"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
@ -43,7 +42,6 @@ class InvokeFrom(StrEnum):
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
InvokeFrom.OPENAPI: "openapi",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
@ -63,7 +63,7 @@ def _get_surface_form_token(
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
@ -11,15 +11,13 @@ from models.human_input import RecipientType
|
||||
class HumanInputSurface(StrEnum):
|
||||
SERVICE_API = "service_api"
|
||||
CONSOLE = "console"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
|
||||
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
|
||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
}
|
||||
|
||||
# A single HITL form can have multiple recipient records; this shared priority
|
||||
|
||||
@ -45,7 +45,6 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
|
||||
)
|
||||
|
||||
|
||||
@ -162,8 +161,6 @@ def create_spec_app() -> Flask:
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.console import console_ns
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web import bp as web_bp
|
||||
@ -172,9 +169,8 @@ def create_spec_app() -> Flask:
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
|
||||
for namespace in (console_ns, web_ns, service_api_ns):
|
||||
for api in namespace.apis:
|
||||
_materialize_inline_model_definitions(api)
|
||||
|
||||
@ -205,13 +201,6 @@ def _registered_models(namespace: str) -> dict[str, object]:
|
||||
for api in service_api_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "openapi":
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
models = dict(openapi_ns.models)
|
||||
for api in openapi_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
|
||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||
|
||||
|
||||
@ -8,8 +8,6 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
|
||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
|
||||
OPENAPI_MAX_AGE_SECONDS: int = 600
|
||||
|
||||
|
||||
def _apply_cors_once(bp, /, **cors_kwargs):
|
||||
@ -31,7 +29,6 @@ def init_app(app: DifyApp):
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.mcp import bp as mcp_bp
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.trigger import bp as trigger_bp
|
||||
from controllers.web import bp as web_bp
|
||||
@ -44,23 +41,6 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
if dify_config.OPENAPI_ENABLED:
|
||||
# User-scoped programmatic API. Default empty allowlist = same-origin
|
||||
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
|
||||
# integrations. supports_credentials so cookie-authed approve/deny
|
||||
# work; cross-origin OPTIONS without an allowed origin will fail
|
||||
# the same as on the console blueprint.
|
||||
_apply_cors_once(
|
||||
openapi_bp,
|
||||
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
_apply_cors_once(
|
||||
web_bp,
|
||||
resources={
|
||||
|
||||
@ -222,12 +222,6 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
|
||||
from libs.token import extract_access_token, extract_webapp_passport
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -84,24 +84,6 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
# approval-context) sit under @login_required and authenticate via
|
||||
# the console session cookie. Cookie-only on purpose — bearer
|
||||
# tokens (dfoa_/dfoe_) live on the Authorization header and are
|
||||
# validated by AppPipeline, not flask-login.
|
||||
cookie_token = extract_console_cookie_token(request)
|
||||
if not cookie_token:
|
||||
return None
|
||||
try:
|
||||
decoded = PassportService().verify(cookie_token)
|
||||
except Exception:
|
||||
return None
|
||||
user_id = decoded.get("user_id")
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import build_and_bind
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return dify_config.ENABLE_OAUTH_BEARER
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
# scoped_session isn't a context manager; request teardown closes it.
|
||||
def session_factory():
|
||||
return db.session
|
||||
|
||||
build_and_bind(session_factory=session_factory, redis_client=redis_client)
|
||||
@ -1,196 +0,0 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
|
||||
from flask import Blueprint
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
from libs.token import is_secure
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# enterprise_only decorator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
|
||||
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
|
||||
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
|
||||
|
||||
|
||||
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||
responses don't consume the bucket.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status not in _EE_ENABLED_STATUSES:
|
||||
raise NotFound()
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# approval_grant cookie
|
||||
# ============================================================================
|
||||
|
||||
|
||||
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
|
||||
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
|
||||
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
|
||||
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
|
||||
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
|
||||
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ApprovalGrantClaims:
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
user_code: str
|
||||
nonce: str
|
||||
csrf_token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
def mint_approval_grant(
|
||||
*,
|
||||
keyset: jws.KeySet,
|
||||
iss: str,
|
||||
subject_email: str,
|
||||
subject_issuer: str,
|
||||
user_code: str,
|
||||
) -> tuple[str, ApprovalGrantClaims]:
|
||||
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
|
||||
source of truth for Path/HttpOnly/Secure/SameSite.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
nonce = _random_opaque()
|
||||
csrf_token = _random_opaque()
|
||||
|
||||
payload = {
|
||||
"iss": iss,
|
||||
"subject_email": subject_email,
|
||||
"subject_issuer": subject_issuer,
|
||||
"user_code": user_code,
|
||||
"nonce": nonce,
|
||||
"csrf_token": csrf_token,
|
||||
}
|
||||
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
|
||||
return token, ApprovalGrantClaims(
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
user_code=user_code,
|
||||
nonce=nonce,
|
||||
csrf_token=csrf_token,
|
||||
expires_at=exp,
|
||||
)
|
||||
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
|
||||
"""Sig + aud + exp only — nonce consumption is the caller's job."""
|
||||
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||
return ApprovalGrantClaims(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
user_code=data["user_code"],
|
||||
nonce=data["nonce"],
|
||||
csrf_token=data["csrf_token"],
|
||||
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
|
||||
)
|
||||
|
||||
|
||||
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def approval_grant_cookie_kwargs(value: str) -> dict:
|
||||
"""``secure`` follows is_secure() so HTTP-only deployments don't
|
||||
silently drop the cookie.
|
||||
"""
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": value,
|
||||
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def approval_grant_cleared_cookie_kwargs() -> dict:
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": "",
|
||||
"max_age": 0,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def _random_opaque() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Anti-framing headers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_ANTI_FRAMING_HEADERS = {
|
||||
"X-Frame-Options": "DENY",
|
||||
"Content-Security-Policy": "frame-ancestors 'none'",
|
||||
}
|
||||
|
||||
|
||||
def attach_anti_framing(bp: Blueprint) -> None:
|
||||
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
||||
|
||||
@bp.after_request
|
||||
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
||||
for name, value in _ANTI_FRAMING_HEADERS.items():
|
||||
response.headers.setdefault(name, value)
|
||||
return response
|
||||
@ -76,7 +76,6 @@ def register_external_error_handlers(api: Api):
|
||||
|
||||
def handle_value_error(e: ValueError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
current_app.logger.exception("value_error in request handler")
|
||||
status_code = 400
|
||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
@ -578,18 +578,3 @@ class RateLimiter:
|
||||
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
def seconds_until_available(self, email: str) -> int:
|
||||
"""Seconds until the oldest in-window entry expires, freeing a slot.
|
||||
|
||||
Defensive floor of 1 second. Caller should only invoke this after
|
||||
is_rate_limited() returned True.
|
||||
"""
|
||||
key = self._get_key(email)
|
||||
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
|
||||
if not oldest:
|
||||
return 1
|
||||
_member, score = oldest[0]
|
||||
free_at = int(score) + self.time_window
|
||||
remaining = free_at - int(time.time())
|
||||
return max(remaining, 1)
|
||||
|
||||
108
api/libs/jws.py
108
api/libs/jws.py
@ -1,108 +0,0 @@
|
||||
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
||||
state envelope, external subject assertion, and approval-grant cookie —
|
||||
all three share one key-set so api ↔ enterprise can verify each other.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
|
||||
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
|
||||
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
|
||||
|
||||
ACTIVE_KID_V1 = "dify-shared-v1"
|
||||
|
||||
|
||||
class KeySetError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KeySet:
|
||||
"""``from_entries`` reserves multi-kid construction for rotation slots."""
|
||||
|
||||
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
|
||||
if active_kid not in entries:
|
||||
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
|
||||
if not entries[active_kid]:
|
||||
raise KeySetError(f"active kid {active_kid!r} has empty secret")
|
||||
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
|
||||
self._active_kid = active_kid
|
||||
|
||||
@classmethod
|
||||
def from_shared_secret(cls) -> KeySet:
|
||||
secret = dify_config.SECRET_KEY
|
||||
if not secret:
|
||||
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
|
||||
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
|
||||
|
||||
@classmethod
|
||||
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
|
||||
return cls(entries, active_kid)
|
||||
|
||||
@property
|
||||
def active_kid(self) -> str:
|
||||
return self._active_kid
|
||||
|
||||
def lookup(self, kid: str) -> bytes | None:
|
||||
return self._entries.get(kid)
|
||||
|
||||
|
||||
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
|
||||
"""``iat`` + ``exp`` are injected here; callers must not set them."""
|
||||
if "aud" in payload or "iat" in payload or "exp" in payload:
|
||||
raise ValueError("reserved claim present in payload (aud/iat/exp)")
|
||||
if ttl_seconds <= 0:
|
||||
raise ValueError("ttl_seconds must be positive")
|
||||
|
||||
kid = keyset.active_kid
|
||||
secret = keyset.lookup(kid)
|
||||
if secret is None:
|
||||
raise KeySetError(f"active kid {kid!r} lookup miss")
|
||||
|
||||
iat = datetime.now(UTC)
|
||||
exp = iat + timedelta(seconds=ttl_seconds)
|
||||
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
|
||||
return jwt.encode(
|
||||
claims,
|
||||
secret,
|
||||
algorithm="HS256",
|
||||
headers={"kid": kid, "typ": "JWT"},
|
||||
)
|
||||
|
||||
|
||||
class VerifyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
|
||||
"""Unknown kid is rejected — never fall back to the active kid, since
|
||||
a past kid value would otherwise be forgeable by anyone who saw it.
|
||||
"""
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except jwt.PyJWTError as e:
|
||||
raise VerifyError(f"decode header: {e}") from e
|
||||
kid = header.get("kid")
|
||||
if not kid:
|
||||
raise VerifyError("no kid in header")
|
||||
secret = keyset.lookup(kid)
|
||||
if secret is None:
|
||||
raise VerifyError(f"unknown kid {kid!r}")
|
||||
try:
|
||||
return jwt.decode(
|
||||
token,
|
||||
secret,
|
||||
algorithms=["HS256"],
|
||||
audience=expected_aud,
|
||||
)
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise VerifyError("token expired") from e
|
||||
except jwt.InvalidAudienceError as e:
|
||||
raise VerifyError("aud mismatch") from e
|
||||
except jwt.PyJWTError as e:
|
||||
raise VerifyError(f"decode: {e}") from e
|
||||
@ -1,650 +0,0 @@
|
||||
"""OAuth bearer primitives.
|
||||
|
||||
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||
Authenticator + validate_bearer stay untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Literal, ParamSpec, Protocol, TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.rate_limit import enforce_bearer_rate_limit
|
||||
from models import Account, OAuthAccessToken, TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Contract — types, enums, protocols
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SubjectType(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
class Scope(StrEnum):
|
||||
"""Catalog of bearer scopes recognised by the openapi surface.
|
||||
|
||||
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
|
||||
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
|
||||
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
|
||||
"""
|
||||
|
||||
FULL = "full"
|
||||
APPS_READ = "apps:read"
|
||||
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
|
||||
APPS_RUN = "apps:run"
|
||||
|
||||
|
||||
class Accepts(StrEnum):
|
||||
"""Subject types a route is willing to accept as caller."""
|
||||
|
||||
USER_ACCOUNT = "user_account"
|
||||
USER_EXT_SSO = "user_ext_sso"
|
||||
|
||||
|
||||
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
|
||||
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
|
||||
|
||||
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
|
||||
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
|
||||
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class AuthContext:
|
||||
"""Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
|
||||
come from the TokenKind, not the DB — corrupt rows can't elevate scope.
|
||||
|
||||
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
|
||||
authenticate time. Per-request mutations write through to Redis via
|
||||
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
|
||||
"""
|
||||
|
||||
subject_type: SubjectType
|
||||
subject_email: str | None
|
||||
subject_issuer: str | None
|
||||
account_id: uuid.UUID | None
|
||||
client_id: str | None
|
||||
scopes: frozenset[Scope]
|
||||
token_id: uuid.UUID
|
||||
source: str
|
||||
expires_at: datetime | None
|
||||
token_hash: str
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedRow:
|
||||
subject_email: str | None
|
||||
subject_issuer: str | None
|
||||
account_id: uuid.UUID | None
|
||||
client_id: str | None
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime | None
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def to_cache(self) -> dict:
|
||||
return {
|
||||
"subject_email": self.subject_email,
|
||||
"subject_issuer": self.subject_issuer,
|
||||
"account_id": str(self.account_id) if self.account_id else None,
|
||||
"client_id": self.client_id,
|
||||
"token_id": str(self.token_id),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"verified_tenants": dict(self.verified_tenants),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_cache(cls, data: dict) -> ResolvedRow:
|
||||
return cls(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
|
||||
client_id=data.get("client_id"),
|
||||
token_id=uuid.UUID(data["token_id"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
|
||||
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
|
||||
)
|
||||
|
||||
|
||||
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
|
||||
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
|
||||
|
||||
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
|
||||
on all live deployments (60s TTL → safe to drop one release after rollout).
|
||||
"""
|
||||
if not isinstance(raw, dict):
|
||||
return {}
|
||||
out: dict[str, bool] = {}
|
||||
for k, v in raw.items():
|
||||
if isinstance(v, bool):
|
||||
out[k] = v
|
||||
elif v == "ok":
|
||||
out[k] = True
|
||||
elif v == "denied":
|
||||
out[k] = False
|
||||
return out
|
||||
|
||||
|
||||
class Resolver(Protocol):
|
||||
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TokenKind:
|
||||
prefix: str
|
||||
subject_type: SubjectType
|
||||
scopes: frozenset[Scope]
|
||||
source: str
|
||||
resolver: Resolver
|
||||
|
||||
def matches(self, token: str) -> bool:
|
||||
return token.startswith(self.prefix)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MintProfile:
|
||||
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
|
||||
|
||||
Consumers:
|
||||
- ``build_registry`` reads scopes here so the resolve-time TokenKind
|
||||
cannot drift from the mint-time intent.
|
||||
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
|
||||
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
|
||||
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
|
||||
the (subject_type, prefix, scopes) triple a caller intends to mint
|
||||
against this table — a caller that assembles its own scope set
|
||||
from a non-canonical source will fail closed at approve time.
|
||||
"""
|
||||
|
||||
subject_type: SubjectType
|
||||
prefix: str
|
||||
scopes: frozenset[Scope]
|
||||
|
||||
|
||||
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
|
||||
SubjectType.ACCOUNT: MintProfile(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
prefix="dfoa_",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: MintProfile(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
prefix="dfoe_",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class InvalidBearerError(Exception):
|
||||
"""Token missing, unknown prefix, or no live row."""
|
||||
|
||||
|
||||
class TokenExpiredError(Exception):
|
||||
"""Hard-expire bookkeeping is the resolver's job before raising."""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenKindRegistry:
|
||||
def __init__(self, kinds: Iterable[TokenKind]) -> None:
|
||||
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
|
||||
prefixes = [k.prefix for k in self._kinds]
|
||||
if len(set(prefixes)) != len(prefixes):
|
||||
raise ValueError(f"duplicate prefix in registry: {prefixes}")
|
||||
|
||||
def find(self, token: str) -> TokenKind | None:
|
||||
for k in self._kinds:
|
||||
if k.matches(token):
|
||||
return k
|
||||
return None
|
||||
|
||||
def kinds(self) -> tuple[TokenKind, ...]:
|
||||
return self._kinds
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authenticator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class BearerAuthenticator:
|
||||
def __init__(self, registry: TokenKindRegistry) -> None:
|
||||
self._registry = registry
|
||||
|
||||
@property
|
||||
def registry(self) -> TokenKindRegistry:
|
||||
return self._registry
|
||||
|
||||
def authenticate(self, token: str) -> AuthContext:
|
||||
"""Identity + per-token rate limit (single source).
|
||||
|
||||
Both the openapi pipeline (`BearerCheck`) and the decorator
|
||||
(`validate_bearer`) call this — rate-limit fires exactly once per
|
||||
request regardless of which path hosts the route.
|
||||
"""
|
||||
kind = self._registry.find(token)
|
||||
if kind is None:
|
||||
raise InvalidBearerError("unknown token prefix")
|
||||
token_hash = sha256_hex(token)
|
||||
row = kind.resolver.resolve(token_hash)
|
||||
if row is None:
|
||||
raise InvalidBearerError("token unknown or revoked")
|
||||
enforce_bearer_rate_limit(token_hash)
|
||||
return AuthContext(
|
||||
subject_type=kind.subject_type,
|
||||
subject_email=row.subject_email,
|
||||
subject_issuer=row.subject_issuer,
|
||||
account_id=row.account_id,
|
||||
client_id=row.client_id,
|
||||
scopes=kind.scopes,
|
||||
token_id=row.token_id,
|
||||
source=kind.source,
|
||||
expires_at=row.expires_at,
|
||||
token_hash=token_hash,
|
||||
verified_tenants=dict(row.verified_tenants),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth access token resolver (PAT resolver would be a sibling class)
|
||||
# ============================================================================
|
||||
|
||||
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
|
||||
POSITIVE_TTL_SECONDS = 60
|
||||
NEGATIVE_TTL_SECONDS = 10
|
||||
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
|
||||
|
||||
ScopeVariant = Literal["account", "external_sso"]
|
||||
|
||||
|
||||
class OAuthAccessTokenResolver:
|
||||
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
|
||||
sharing DB + cache plumbing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory,
|
||||
redis_client,
|
||||
positive_ttl: int = POSITIVE_TTL_SECONDS,
|
||||
negative_ttl: int = NEGATIVE_TTL_SECONDS,
|
||||
) -> None:
|
||||
self.session_factory = session_factory
|
||||
self._redis = redis_client
|
||||
self._positive_ttl = positive_ttl
|
||||
self._negative_ttl = negative_ttl
|
||||
|
||||
def for_account(self) -> Resolver:
|
||||
return _VariantResolver(self, variant="account")
|
||||
|
||||
def for_external_sso(self) -> Resolver:
|
||||
return _VariantResolver(self, variant="external_sso")
|
||||
|
||||
def _cache_key(self, token_hash: str) -> str:
|
||||
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||
|
||||
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
|
||||
raw = self._redis.get(self._cache_key(token_hash))
|
||||
if raw is None:
|
||||
return None
|
||||
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
if text == "invalid":
|
||||
return "invalid"
|
||||
try:
|
||||
return ResolvedRow.from_cache(json.loads(text))
|
||||
except (ValueError, KeyError):
|
||||
logger.warning("auth:token cache entry malformed; treating as miss")
|
||||
return None
|
||||
|
||||
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
|
||||
self._redis.setex(
|
||||
self._cache_key(token_hash),
|
||||
self._positive_ttl,
|
||||
json.dumps(row.to_cache()),
|
||||
)
|
||||
|
||||
def cache_set_negative(self, token_hash: str) -> None:
|
||||
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
|
||||
|
||||
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
|
||||
"""Atomic CAS — only the worker that flips revoked_at emits audit;
|
||||
replays are idempotent.
|
||||
"""
|
||||
stmt = (
|
||||
update(OAuthAccessToken)
|
||||
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
|
||||
.values(revoked_at=datetime.now(UTC), token_hash=None)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
if result.rowcount == 1:
|
||||
logger.warning(
|
||||
"audit: %s token_id=%s",
|
||||
AUDIT_OAUTH_EXPIRED,
|
||||
row_id,
|
||||
extra={"audit": True, "token_id": str(row_id)},
|
||||
)
|
||||
self._redis.delete(self._cache_key(token_hash))
|
||||
self.cache_set_negative(token_hash)
|
||||
|
||||
|
||||
class _VariantResolver:
|
||||
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
|
||||
self._parent = parent
|
||||
self._variant = variant
|
||||
|
||||
def resolve(self, token_hash: str) -> ResolvedRow | None:
|
||||
cached = self._parent.cache_get(token_hash)
|
||||
if cached == "invalid":
|
||||
return None
|
||||
if cached is not None and not isinstance(cached, str):
|
||||
if not self._matches_variant(cached):
|
||||
return None
|
||||
return cached
|
||||
|
||||
# Flask-SQLAlchemy's scoped_session is request-bound and not a
|
||||
# context manager; use it directly.
|
||||
session = self._parent.session_factory()
|
||||
row = self._load_from_db(session, token_hash)
|
||||
if row is None:
|
||||
self._parent.cache_set_negative(token_hash)
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
if row.expires_at is not None and row.expires_at <= now:
|
||||
self._parent.hard_expire(session, row.id, token_hash)
|
||||
return None
|
||||
|
||||
if not self._matches_variant_model(row):
|
||||
logger.error(
|
||||
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
|
||||
row.id,
|
||||
row.prefix,
|
||||
)
|
||||
return None
|
||||
|
||||
resolved = ResolvedRow(
|
||||
subject_email=row.subject_email,
|
||||
subject_issuer=row.subject_issuer,
|
||||
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
|
||||
client_id=row.client_id,
|
||||
token_id=uuid.UUID(str(row.id)),
|
||||
expires_at=row.expires_at,
|
||||
)
|
||||
self._parent.cache_set_positive(token_hash, resolved)
|
||||
return resolved
|
||||
|
||||
def _matches_variant(self, row: ResolvedRow) -> bool:
|
||||
has_account = row.account_id is not None
|
||||
if self._variant == "account":
|
||||
return has_account
|
||||
return not has_account
|
||||
|
||||
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
|
||||
has_account = row.account_id is not None
|
||||
if self._variant == "account":
|
||||
return has_account and row.prefix == "dfoa_"
|
||||
return (not has_account) and row.prefix == "dfoe_"
|
||||
|
||||
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
|
||||
return (
|
||||
session.query(OAuthAccessToken)
|
||||
.filter(
|
||||
OAuthAccessToken.token_hash == token_hash,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Layer 0 — workspace membership cache + helper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
|
||||
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
|
||||
`auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
|
||||
rebuilds via authenticate() and re-runs Layer 0.
|
||||
"""
|
||||
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
|
||||
raw = redis_client.get(cache_key)
|
||||
if raw is None:
|
||||
return
|
||||
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
if text == "invalid":
|
||||
return
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
ttl = redis_client.ttl(cache_key)
|
||||
if ttl <= 0:
|
||||
return
|
||||
data.setdefault("verified_tenants", {})[tenant_id] = verdict
|
||||
redis_client.setex(cache_key, ttl, json.dumps(data))
|
||||
|
||||
|
||||
def check_workspace_membership(
|
||||
*,
|
||||
account_id: uuid.UUID | str,
|
||||
tenant_id: str,
|
||||
token_hash: str,
|
||||
cached_verdicts: dict[str, bool],
|
||||
) -> None:
|
||||
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
|
||||
|
||||
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
|
||||
inline helper (`require_workspace_member`). Caller is responsible for
|
||||
short-circuiting on EE / SSO subjects before invoking — this function
|
||||
runs the membership + active-status checks unconditionally.
|
||||
"""
|
||||
cached = cached_verdicts.get(tenant_id)
|
||||
if cached is True:
|
||||
return
|
||||
if cached is False:
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
join = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if join is None:
|
||||
record_layer0_verdict(token_hash, tenant_id, False)
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
|
||||
if status != "active":
|
||||
record_layer0_verdict(token_hash, tenant_id, False)
|
||||
raise Forbidden("workspace_membership_revoked")
|
||||
|
||||
record_layer0_verdict(token_hash, tenant_id, True)
|
||||
|
||||
|
||||
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
|
||||
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
|
||||
|
||||
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
|
||||
(no `tenant_account_joins` row by definition).
|
||||
"""
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||
return
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=tenant_id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.verified_tenants,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decorator — route-level bearer gate
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_authenticator: BearerAuthenticator | None = None
|
||||
|
||||
|
||||
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
|
||||
global _authenticator
|
||||
_authenticator = authenticator
|
||||
|
||||
|
||||
def get_authenticator() -> BearerAuthenticator:
|
||||
if _authenticator is None:
|
||||
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
|
||||
return _authenticator
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
header = req.headers.get("Authorization", "")
|
||||
scheme, _, value = header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not value:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
_DP = ParamSpec("_DP")
|
||||
_DR = TypeVar("_DR")
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
|
||||
"""Opt-in: omitting it leaves the route unauthenticated.
|
||||
|
||||
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
|
||||
``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
|
||||
and are rejected here as the wrong auth scheme for this surface.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
|
||||
@wraps(fn)
|
||||
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
|
||||
token = _extract_bearer(request)
|
||||
if token is None:
|
||||
raise Unauthorized("missing bearer token")
|
||||
|
||||
if _authenticator is None:
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
|
||||
try:
|
||||
ctx = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
|
||||
raise Forbidden("token subject type not accepted here")
|
||||
|
||||
g.auth_ctx = ctx
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
||||
without the authenticator, so fail fast instead of approving silently.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ENABLE_OAUTH_BEARER:
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def require_scope(scope: Scope) -> Callable:
|
||||
"""Route-level scope gate — must run AFTER validate_bearer so that
|
||||
g.auth_ctx is set. Raises Forbidden('insufficient_scope: <scope>')
|
||||
when the bearer lacks both the requested scope and `Scope.FULL`.
|
||||
"""
|
||||
|
||||
def wrap(fn: Callable) -> Callable:
|
||||
@wraps(fn)
|
||||
def inner(*args, **kwargs):
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
|
||||
)
|
||||
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
|
||||
raise Forbidden(f"insufficient_scope: {scope}")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Wiring — called once from the app factory
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
|
||||
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
|
||||
account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
return TokenKindRegistry(
|
||||
[
|
||||
TokenKind(
|
||||
prefix=account.prefix,
|
||||
subject_type=account.subject_type,
|
||||
scopes=account.scopes,
|
||||
source="oauth_account",
|
||||
resolver=oauth.for_account(),
|
||||
),
|
||||
TokenKind(
|
||||
prefix=external.prefix,
|
||||
subject_type=external.subject_type,
|
||||
scopes=external.scopes,
|
||||
source="oauth_external_sso",
|
||||
resolver=oauth.for_external_sso(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
|
||||
registry = build_registry(session_factory, redis_client)
|
||||
auth = BearerAuthenticator(registry)
|
||||
bind_authenticator(auth)
|
||||
return auth
|
||||
@ -1,140 +0,0 @@
|
||||
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
||||
window Redis ZSET). Apply after auth decorators so scopes can read
|
||||
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
|
||||
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
||||
generic 429.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import g, jsonify, make_response, request, session
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
|
||||
|
||||
class RateLimitScope(StrEnum):
|
||||
IP = "ip"
|
||||
SESSION = "session"
|
||||
ACCOUNT = "account"
|
||||
SUBJECT_EMAIL = "subject_email"
|
||||
TOKEN_ID = "token_id"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RateLimit:
|
||||
limit: int
|
||||
window: timedelta
|
||||
scopes: tuple[RateLimitScope, ...]
|
||||
|
||||
|
||||
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
|
||||
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
|
||||
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
|
||||
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
|
||||
LIMIT_BEARER_PER_TOKEN = RateLimit(
|
||||
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
|
||||
window=timedelta(minutes=1),
|
||||
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
|
||||
)
|
||||
|
||||
|
||||
def _one_key(scope: RateLimitScope) -> str:
|
||||
match scope:
|
||||
case RateLimitScope.IP:
|
||||
return f"ip:{extract_remote_ip(request) or 'unknown'}"
|
||||
case RateLimitScope.SESSION:
|
||||
return f"session:{session.get('_id', 'anon')}"
|
||||
case RateLimitScope.ACCOUNT:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.account_id:
|
||||
return f"account:{ctx.account_id}"
|
||||
return "account:anon"
|
||||
case RateLimitScope.SUBJECT_EMAIL:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.subject_email:
|
||||
return f"subject:{ctx.subject_email}"
|
||||
return "subject:anon"
|
||||
case RateLimitScope.TOKEN_ID:
|
||||
ctx = getattr(g, "auth_ctx", None)
|
||||
if ctx and ctx.token_id:
|
||||
return f"token:{ctx.token_id}"
|
||||
return "token:anon"
|
||||
|
||||
|
||||
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "|".join(_one_key(s) for s in scopes)
|
||||
|
||||
|
||||
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "rl:" + "+".join(s.value for s in scopes)
|
||||
|
||||
|
||||
def _build_limiter(spec: RateLimit) -> RateLimiter:
|
||||
return RateLimiter(
|
||||
prefix=_limiter_prefix(spec.scopes),
|
||||
max_attempts=spec.limit,
|
||||
time_window=int(spec.window.total_seconds()),
|
||||
)
|
||||
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
|
||||
"""Apply after auth decorators that the scopes read from."""
|
||||
limiter = _build_limiter(spec)
|
||||
|
||||
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
key = _composite_key(spec.scopes)
|
||||
if limiter.is_rate_limited(key):
|
||||
raise TooManyRequests("rate_limited")
|
||||
limiter.increment_rate_limit(key)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
def enforce(spec: RateLimit, *, key: str) -> None:
|
||||
"""Imperative form — caller composes the bucket key to match scope
|
||||
semantics (the key is opaque here).
|
||||
"""
|
||||
limiter = _build_limiter(spec)
|
||||
if limiter.is_rate_limited(key):
|
||||
raise TooManyRequests("rate_limited")
|
||||
limiter.increment_rate_limit(key)
|
||||
|
||||
|
||||
def enforce_bearer_rate_limit(token_hash: str) -> None:
|
||||
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
|
||||
|
||||
Bucket key = ``token:<sha256_hex>`` so the same token shares one
|
||||
bucket across api replicas (Redis-backed sliding window).
|
||||
"""
|
||||
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
|
||||
key = f"token:{token_hash}"
|
||||
if limiter.is_rate_limited(key):
|
||||
retry_after = limiter.seconds_until_available(key)
|
||||
response = make_response(
|
||||
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
|
||||
429,
|
||||
)
|
||||
response.headers["Retry-After"] = str(retry_after)
|
||||
raise TooManyRequests(response=response)
|
||||
limiter.increment_rate_limit(key)
|
||||
@ -72,15 +72,11 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||
|
||||
|
||||
def extract_console_cookie_token(request: Request) -> str | None:
|
||||
"""Cookie-only console session token. Used by /openapi/v1/oauth/device/*
|
||||
approval routes, which must not fall through to the Authorization header
|
||||
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
|
||||
return extract_console_cookie_token(request) or _try_extract_from_header(request)
|
||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
|
||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
||||
|
||||
|
||||
def extract_webapp_access_token(request: Request) -> str | None:
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
"""add oauth_access_tokens table
|
||||
|
||||
Revision ID: d4a5e1f3c9b7
|
||||
Revises: a4f2d8c9b731
|
||||
Create Date: 2026-04-23 22:00:00.000000
|
||||
|
||||
Table stores user-level OAuth bearer tokens minted via the device-flow grant
|
||||
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
|
||||
table not added in this migration.
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d4a5e1f3c9b7"
|
||||
down_revision = "a4f2d8c9b731"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"oauth_access_tokens",
|
||||
sa.Column(
|
||||
"id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
server_default=sa.text("gen_random_uuid()"),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column("subject_email", sa.Text(), nullable=False),
|
||||
sa.Column("subject_issuer", sa.Text(), nullable=True),
|
||||
sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("client_id", sa.String(length=64), nullable=False),
|
||||
sa.Column("device_label", sa.Text(), nullable=False),
|
||||
sa.Column("prefix", sa.String(length=8), nullable=False),
|
||||
sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("NOW()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
|
||||
sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["account_id"],
|
||||
["accounts.id"],
|
||||
name="fk_oauth_access_tokens_account_id",
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"idx_oauth_subject_email",
|
||||
"oauth_access_tokens",
|
||||
["subject_email"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_account",
|
||||
"oauth_access_tokens",
|
||||
["account_id"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_client",
|
||||
"oauth_access_tokens",
|
||||
["subject_email", "client_id"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_oauth_token_hash",
|
||||
"oauth_access_tokens",
|
||||
["token_hash"],
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
# Partial unique index — rotate-in-place keyed on (subject, client, device).
|
||||
# The app always writes a non-NULL subject_issuer (account flow uses a
|
||||
# sentinel, external-SSO uses the verified IdP issuer); without that the
|
||||
# composite key would never collide because Postgres treats NULLs as
|
||||
# distinct in unique indices.
|
||||
op.create_index(
|
||||
"uq_oauth_active_per_device",
|
||||
"oauth_access_tokens",
|
||||
["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("revoked_at IS NULL"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_client", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_account", table_name="oauth_access_tokens")
|
||||
op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens")
|
||||
op.drop_table("oauth_access_tokens")
|
||||
@ -73,7 +73,7 @@ from .model import (
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from .provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -177,7 +177,6 @@ __all__ = [
|
||||
"MessageChain",
|
||||
"MessageFeedback",
|
||||
"MessageFile",
|
||||
"OAuthAccessToken",
|
||||
"OperationLog",
|
||||
"PinnedConversation",
|
||||
"Provider",
|
||||
|
||||
@ -185,7 +185,6 @@ class InvokeFrom(StrEnum):
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
@ -198,7 +197,6 @@ class InvokeFrom(StrEnum):
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
InvokeFrom.OPENAPI: "openapi",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
@ -84,35 +84,3 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(TypeBase):
|
||||
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
|
||||
subject_issuer = "dify:account" sentinel); account_id NULL +
|
||||
subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
|
||||
subject_issuer is non-NULL for all rows the app writes — Postgres
|
||||
treats NULLs as distinct in unique indices, so the partial unique
|
||||
index on (subject_email, subject_issuer, client_id, device_label)
|
||||
WHERE revoked_at IS NULL would otherwise fail to rotate in place.
|
||||
"""
|
||||
|
||||
__tablename__ = "oauth_access_tokens"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
|
||||
)
|
||||
subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
|
||||
device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
||||
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
|
||||
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
|
||||
subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
|
||||
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
|
||||
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
|
||||
)
|
||||
|
||||
@ -1209,7 +1209,6 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
INSTALLED_APP = "installed-app"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
|
||||
(token_id stays for audits) so rows accumulate across re-logins, and
|
||||
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import click
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 500
|
||||
|
||||
|
||||
@app.celery.task(queue="retention")
|
||||
def clean_oauth_access_tokens_task():
|
||||
click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
|
||||
retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
|
||||
cutoff = datetime.now(UTC) - timedelta(days=retention_days)
|
||||
start_at = time.perf_counter()
|
||||
|
||||
candidates = or_(
|
||||
OAuthAccessToken.revoked_at < cutoff,
|
||||
# Zombies: expired but never re-presented, so middleware never flipped them.
|
||||
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
|
||||
)
|
||||
|
||||
total = 0
|
||||
while True:
|
||||
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
|
||||
if not ids:
|
||||
break
|
||||
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
|
||||
db.session.commit()
|
||||
total += len(ids)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -39,8 +39,6 @@ class AppListParams(BaseModel):
|
||||
name: str | None = None
|
||||
tag_ids: list[str] | None = None
|
||||
is_created_by_me: bool | None = None
|
||||
status: str | None = None
|
||||
openapi_visible: bool = False
|
||||
|
||||
|
||||
class CreateAppParams(BaseModel):
|
||||
@ -77,14 +75,6 @@ class AppService:
|
||||
elif params.mode == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if params.status:
|
||||
filters.append(App.status == params.status)
|
||||
# OpenAPI surface visibility gate. Pushed into the query so
|
||||
# `pagination.total` reflects only apps the openapi caller can
|
||||
# actually reach — post-filtering by enable_api after the page
|
||||
# arrives would make `total` page-dependent.
|
||||
if params.openapi_visible:
|
||||
filters.append(App.enable_api.is_(True))
|
||||
if params.is_created_by_me:
|
||||
filters.append(App.created_by == user_id)
|
||||
if params.name:
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class PermittedAppsPage:
|
||||
app_ids: list[str]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
def list_permitted_apps(
|
||||
*,
|
||||
page: int,
|
||||
limit: int,
|
||||
mode: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> PermittedAppsPage:
|
||||
try:
|
||||
body = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
|
||||
page=page, limit=limit, mode=mode, name=name
|
||||
)
|
||||
except EnterpriseAPIError as exc:
|
||||
logger.warning(
|
||||
"permitted_apps EE call failed: status=%s message=%s",
|
||||
getattr(exc, "status_code", None),
|
||||
str(exc),
|
||||
)
|
||||
raise ServiceUnavailable("permitted_apps_unavailable") from exc
|
||||
|
||||
return PermittedAppsPage(
|
||||
app_ids=[row["appId"] for row in body.get("data", [])],
|
||||
total=int(body.get("total", 0)),
|
||||
has_more=bool(body.get("hasMore", False)),
|
||||
)
|
||||
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
@ -25,22 +24,10 @@ VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||
|
||||
|
||||
class WebAppAccessMode(enum.StrEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
PRIVATE_ALL = "private_all"
|
||||
SSO_VERIFIED = "sso_verified"
|
||||
|
||||
|
||||
PERMISSION_CHECK_MODES: frozenset[WebAppAccessMode] = frozenset(
|
||||
{WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||
)
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
description=f"Access mode for the web app. One of: {', '.join(m.value for m in WebAppAccessMode)}",
|
||||
default=WebAppAccessMode.PRIVATE.value,
|
||||
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'",
|
||||
default="private",
|
||||
alias="accessMode",
|
||||
)
|
||||
|
||||
@ -121,15 +108,6 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/device-flow/sso-initiate",
|
||||
json={"signed_state": signed_state},
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
@ -241,9 +219,8 @@ class EnterpriseService:
|
||||
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||
if not app_id:
|
||||
raise ValueError("app_id must be provided.")
|
||||
allowed = {WebAppAccessMode.PUBLIC, WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
|
||||
if access_mode not in allowed:
|
||||
raise ValueError(f"access_mode must be one of: {', '.join(m.value for m in allowed)}")
|
||||
if access_mode not in ["public", "private", "private_all"]:
|
||||
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
||||
|
||||
data = {"appId": app_id, "accessMode": access_mode}
|
||||
|
||||
@ -259,32 +236,6 @@ class EnterpriseService:
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def list_externally_accessible_apps(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
limit: int,
|
||||
mode: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> dict:
|
||||
"""Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.
|
||||
|
||||
Response shape: ``{"data": [{"appId", "tenantId", "mode", "name", "updatedAt"}],
|
||||
"total": int, "hasMore": bool}``.
|
||||
"""
|
||||
body: dict[str, str | int] = {"page": page, "limit": limit}
|
||||
if mode is not None:
|
||||
body["mode"] = mode
|
||||
if name is not None:
|
||||
body["name"] = name
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/webapp/externally-accessible-apps",
|
||||
json=body,
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
@ -1,467 +0,0 @@
|
||||
"""Device-flow service layer: Redis state machine, OAuth token mint
|
||||
(DB upsert + plaintext generation), and TTL policy. Specs:
|
||||
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Redis state machine — device_code + user_code ephemeral state
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_DEVICE_CODE_KEY_PREFIX = "device_code:"
|
||||
_USER_CODE_KEY_PREFIX = "user_code:"
|
||||
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
|
||||
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
|
||||
|
||||
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
|
||||
# not both observe APPROVED — only the winner gets the plaintext token,
|
||||
# the loser sees nil and the caller maps that to expired_token.
|
||||
_CONSUME_ON_POLL_LUA = """
|
||||
local raw = redis.call('GET', KEYS[1])
|
||||
if not raw then return nil end
|
||||
local ok, decoded = pcall(cjson.decode, raw)
|
||||
if not ok then return nil end
|
||||
if decoded.status == 'pending' then return nil end
|
||||
if decoded.user_code then
|
||||
redis.call('DEL', ARGV[1] .. decoded.user_code)
|
||||
end
|
||||
redis.call('DEL', KEYS[1])
|
||||
return raw
|
||||
"""
|
||||
|
||||
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
|
||||
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
|
||||
|
||||
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
|
||||
USER_CODE_SEGMENT_LEN = 4
|
||||
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
|
||||
|
||||
DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum
|
||||
|
||||
|
||||
class DeviceFlowStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
|
||||
|
||||
class SlowDownDecision(StrEnum):
|
||||
OK = "ok"
|
||||
SLOW_DOWN = "slow_down"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceFlowState:
|
||||
"""``minted_token`` is plaintext between approve and the next poll;
|
||||
DEL'd after the poll reads it.
|
||||
"""
|
||||
|
||||
user_code: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
status: DeviceFlowStatus
|
||||
subject_email: str | None = None
|
||||
account_id: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
minted_token: str | None = None
|
||||
token_id: str | None = None
|
||||
created_at: str = ""
|
||||
created_ip: str = ""
|
||||
last_poll_at: str = ""
|
||||
poll_payload: dict | None = field(default=None)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> DeviceFlowState:
|
||||
data = json.loads(raw)
|
||||
if "status" in data:
|
||||
data["status"] = DeviceFlowStatus(data["status"])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
def _random_device_code() -> str:
|
||||
return "dc_" + secrets.token_urlsafe(24)
|
||||
|
||||
|
||||
def _random_user_code_segment() -> str:
|
||||
return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
|
||||
|
||||
|
||||
def _random_user_code() -> str:
|
||||
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
|
||||
|
||||
|
||||
class StateNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidTransitionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UserCodeExhaustedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceFlowRedis:
|
||||
def __init__(self, redis_client) -> None:
|
||||
self._redis = redis_client
|
||||
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
|
||||
|
||||
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
|
||||
device_code = _random_device_code()
|
||||
user_code = self._claim_user_code(device_code)
|
||||
state = DeviceFlowState(
|
||||
user_code=user_code,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
status=DeviceFlowStatus.PENDING,
|
||||
created_at=datetime.now(UTC).isoformat(),
|
||||
created_ip=created_ip,
|
||||
)
|
||||
self._redis.setex(
|
||||
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
state.to_json(),
|
||||
)
|
||||
return device_code, user_code, DEVICE_FLOW_TTL_SECONDS
|
||||
|
||||
def _claim_user_code(self, device_code: str) -> str:
|
||||
for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
|
||||
user_code = _random_user_code()
|
||||
key = USER_CODE_KEY_FMT.format(code=user_code)
|
||||
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
|
||||
if ok:
|
||||
return user_code
|
||||
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
|
||||
|
||||
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
|
||||
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
|
||||
if not raw_dc:
|
||||
return None
|
||||
device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
return None
|
||||
return device_code, state
|
||||
|
||||
def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
|
||||
return self._load_state(device_code)
|
||||
|
||||
def _load_state(self, device_code: str) -> DeviceFlowState | None:
|
||||
raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||
if not raw:
|
||||
return None
|
||||
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.exception("device_flow: corrupt state for %s", device_code)
|
||||
return None
|
||||
|
||||
def approve(
|
||||
self,
|
||||
device_code: str,
|
||||
subject_email: str,
|
||||
account_id: str | None,
|
||||
minted_token: str,
|
||||
token_id: str,
|
||||
subject_issuer: str | None = None,
|
||||
poll_payload: dict | None = None,
|
||||
) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransitionError(f"cannot approve {state.status}")
|
||||
|
||||
state.status = DeviceFlowStatus.APPROVED
|
||||
state.subject_email = subject_email
|
||||
state.account_id = account_id
|
||||
state.subject_issuer = subject_issuer
|
||||
state.minted_token = minted_token
|
||||
state.token_id = token_id
|
||||
state.poll_payload = poll_payload
|
||||
|
||||
new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
|
||||
self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())
|
||||
|
||||
def deny(self, device_code: str) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
raise StateNotFoundError(device_code)
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise InvalidTransitionError(f"cannot deny {state.status}")
|
||||
state.status = DeviceFlowStatus.DENIED
|
||||
self._redis.setex(
|
||||
DEVICE_CODE_KEY_FMT.format(code=device_code),
|
||||
self._remaining_ttl(device_code, floor=1),
|
||||
state.to_json(),
|
||||
)
|
||||
|
||||
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
|
||||
"""Race-safe via Lua EVAL: GET + status-check + DEL execute in a
|
||||
single Redis transaction so only one of N concurrent pollers
|
||||
observes the APPROVED state. Losers get None, mapped to
|
||||
expired_token by the caller.
|
||||
"""
|
||||
raw = self._consume_on_poll_script(
|
||||
keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
|
||||
args=[_USER_CODE_KEY_PREFIX],
|
||||
)
|
||||
if raw is None:
|
||||
return None
|
||||
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
|
||||
try:
|
||||
return DeviceFlowState.from_json(text_)
|
||||
except (ValueError, KeyError):
|
||||
logger.exception("device_flow: corrupt state on consume %s", device_code)
|
||||
return None
|
||||
|
||||
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
|
||||
now = time.time()
|
||||
key = f"device_code:{device_code}:last_poll"
|
||||
prev_raw = self._redis.get(key)
|
||||
self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
|
||||
if prev_raw is None:
|
||||
return SlowDownDecision.OK
|
||||
prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
|
||||
try:
|
||||
prev = float(prev_s)
|
||||
except ValueError:
|
||||
return SlowDownDecision.OK
|
||||
if now - prev < interval_seconds:
|
||||
return SlowDownDecision.SLOW_DOWN
|
||||
return SlowDownDecision.OK
|
||||
|
||||
def _remaining_ttl(self, device_code: str, floor: int) -> int:
|
||||
"""``max(remaining, floor)`` — guarantees the CLI has at least
|
||||
``floor`` seconds to poll after a near-expiry approve.
|
||||
"""
|
||||
ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
|
||||
if ttl is None or ttl < 0:
|
||||
return floor
|
||||
return max(int(ttl), floor)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token mint — generate + upsert
|
||||
# ============================================================================
|
||||
|
||||
|
||||
OAUTH_BODY_BYTES = 32 # ~256 bits entropy
|
||||
PREFIX_OAUTH_ACCOUNT = "dfoa_"
|
||||
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
|
||||
|
||||
# Sentinel issuer for account-flow rows. Postgres' default partial unique
|
||||
# index treats NULLs as distinct, which would let two live `dfoa_` rows
|
||||
# share (email, client, device) and break rotate-in-place. Storing a
|
||||
# non-empty literal makes the composite key collide as intended.
|
||||
ACCOUNT_ISSUER_SENTINEL = "dify:account"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class MintResult:
|
||||
"""Plaintext token surfaces to the caller once."""
|
||||
|
||||
token: str
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class UpsertOutcome:
|
||||
token_id: uuid.UUID
|
||||
rotated: bool
|
||||
old_hash: str | None
|
||||
|
||||
|
||||
def generate_token(prefix: str) -> str:
|
||||
return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def mint_oauth_token(
|
||||
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
|
||||
# the wrapper proxies the same execute/commit surface.
|
||||
session: Session | scoped_session,
|
||||
redis_client,
|
||||
*,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
account_id: str | None,
|
||||
client_id: str,
|
||||
device_label: str,
|
||||
prefix: str,
|
||||
ttl_days: int,
|
||||
) -> MintResult:
|
||||
"""Live row rotates in place via partial unique index
|
||||
``uq_oauth_active_per_device``; hard-expired rows are excluded by the
|
||||
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
|
||||
deleted so stale AuthContext drops immediately.
|
||||
"""
|
||||
if prefix == PREFIX_OAUTH_ACCOUNT:
|
||||
# Account flow always writes the sentinel — caller may pass None
|
||||
# (for clarity) or the sentinel itself; nothing else is valid.
|
||||
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
|
||||
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
|
||||
subject_issuer = ACCOUNT_ISSUER_SENTINEL
|
||||
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
|
||||
# Defense in depth: enterprise canonicalises + rejects empty,
|
||||
# but a regression there must not yield a NULL composite key here.
|
||||
if not subject_issuer or not subject_issuer.strip():
|
||||
raise ValueError("external-SSO token requires non-empty subject_issuer")
|
||||
else:
|
||||
raise ValueError(f"unknown oauth prefix: {prefix!r}")
|
||||
|
||||
token = generate_token(prefix)
|
||||
new_hash = sha256_hex(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
|
||||
|
||||
outcome = _upsert(
|
||||
session,
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
account_id=account_id,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
prefix=prefix,
|
||||
new_hash=new_hash,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if outcome.rotated and outcome.old_hash:
|
||||
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
|
||||
|
||||
return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
|
||||
|
||||
|
||||
def _upsert(
|
||||
session: Session | scoped_session,
|
||||
*,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
account_id: str | None,
|
||||
client_id: str,
|
||||
device_label: str,
|
||||
prefix: str,
|
||||
new_hash: str,
|
||||
expires_at: datetime,
|
||||
) -> UpsertOutcome:
|
||||
# Snapshot prior live row's hash for Redis invalidation post-rotate.
|
||||
# subject_issuer is always non-null here (account flow uses sentinel,
|
||||
# external-SSO is validated upstream), so equality matches the index.
|
||||
prior = session.execute(
|
||||
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
|
||||
.where(
|
||||
OAuthAccessToken.subject_email == subject_email,
|
||||
OAuthAccessToken.subject_issuer == subject_issuer,
|
||||
OAuthAccessToken.client_id == client_id,
|
||||
OAuthAccessToken.device_label == device_label,
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
old_hash = prior.token_hash if prior else None
|
||||
|
||||
insert_stmt = pg_insert(OAuthAccessToken).values(
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
account_id=account_id,
|
||||
client_id=client_id,
|
||||
device_label=device_label,
|
||||
prefix=prefix,
|
||||
token_hash=new_hash,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
upsert_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
|
||||
index_where=OAuthAccessToken.revoked_at.is_(None),
|
||||
set_={
|
||||
"token_hash": insert_stmt.excluded.token_hash,
|
||||
"prefix": insert_stmt.excluded.prefix,
|
||||
"account_id": insert_stmt.excluded.account_id,
|
||||
"expires_at": insert_stmt.excluded.expires_at,
|
||||
"created_at": func.now(),
|
||||
"last_used_at": None,
|
||||
},
|
||||
).returning(OAuthAccessToken.id)
|
||||
row = session.execute(upsert_stmt).first()
|
||||
session.commit()
|
||||
|
||||
if row is None:
|
||||
raise RuntimeError("oauth_token upsert returned no row")
|
||||
token_id = uuid.UUID(str(row.id))
|
||||
return UpsertOutcome(
|
||||
token_id=token_id,
|
||||
rotated=prior is not None,
|
||||
old_hash=old_hash,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TTL policy — days new OAuth tokens live
|
||||
# ============================================================================
|
||||
|
||||
|
||||
DEFAULT_OAUTH_TTL_DAYS = 14
|
||||
MIN_TTL_DAYS = 1
|
||||
MAX_TTL_DAYS = 365
|
||||
|
||||
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
|
||||
|
||||
|
||||
def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
|
||||
is deferred; when it lands it wins over the env (Redis-cached 60s).
|
||||
"""
|
||||
_ = tenant_id
|
||||
|
||||
raw = os.environ.get(_TTL_ENV_VAR)
|
||||
if raw is None:
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"%s=%r is not an int; falling back to %d",
|
||||
_TTL_ENV_VAR,
|
||||
raw,
|
||||
DEFAULT_OAUTH_TTL_DAYS,
|
||||
)
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
if value < MIN_TTL_DAYS:
|
||||
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
|
||||
return MIN_TTL_DAYS
|
||||
if value > MAX_TTL_DAYS:
|
||||
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
|
||||
return MAX_TTL_DAYS
|
||||
return value
|
||||
@ -1,52 +0,0 @@
|
||||
"""License gate for the /openapi/v1/permitted-external-apps* surface.
|
||||
|
||||
EE-only. CE deploys (``ENTERPRISE_ENABLED=false``) skip the gate entirely —
|
||||
the EE blueprint chain is what gives CE deploys no callers on this surface
|
||||
in practice, but the explicit short-circuit avoids any test/fixture that
|
||||
flips the surface on without flipping the license.
|
||||
|
||||
Reuses ``FeatureService.get_system_features()`` so the license status
|
||||
travels the same path as the console reads.
|
||||
|
||||
Companion to ``controllers.console.wraps.enterprise_license_required`` —
|
||||
that one is for console (cookie-authed, force-logout 401). This one is
|
||||
for bearer surface (token-authed, 403 ``license_required``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VALID_LICENSE_STATUSES: frozenset[LicenseStatus] = frozenset({LicenseStatus.ACTIVE, LicenseStatus.EXPIRING})
|
||||
|
||||
|
||||
def license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Decorator form. Raises ``Forbidden('license_required')`` when the EE
|
||||
deployment has no valid license. No-op on CE (``ENTERPRISE_ENABLED=false``).
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if dify_config.ENTERPRISE_ENABLED and not _is_license_valid():
|
||||
raise Forbidden(description="license_required")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def _is_license_valid() -> bool:
|
||||
try:
|
||||
features = FeatureService.get_system_features()
|
||||
except Exception:
|
||||
logger.exception("license_gate: FeatureService.get_system_features failed")
|
||||
return False
|
||||
return features.license.status in _VALID_LICENSE_STATUSES
|
||||
@ -1,47 +0,0 @@
|
||||
"""Hard mint policy.
|
||||
|
||||
``validate_mint_policy`` cross-checks a (subject_type, prefix, scopes)
|
||||
triple a caller intends to mint against ``MINTABLE_PROFILES`` —
|
||||
the single source of truth in ``libs.oauth_bearer``.
|
||||
|
||||
The defense-in-depth value: if a future caller assembles ``prefix`` or
|
||||
``scopes`` from a non-canonical source (env, request body, plug-in
|
||||
contribution), the mismatch fails closed at approve time before any
|
||||
row hits the DB. When the caller reads straight from
|
||||
``MINTABLE_PROFILES``, the check is a structural pin — it confirms the
|
||||
table entry is well-formed and the caller picked the right key.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, Scope, SubjectType
|
||||
|
||||
|
||||
class MintPolicyViolation(Exception): # noqa: N818 — spec-defined name, used in BadRequest message
|
||||
"""Raised on a (subject_type, prefix, scopes) mismatch. Callers translate
|
||||
to 400 ``mint_policy_violation``."""
|
||||
|
||||
|
||||
def validate_mint_policy(
|
||||
*,
|
||||
subject_type: SubjectType,
|
||||
prefix: str,
|
||||
scopes: frozenset[Scope],
|
||||
) -> None:
|
||||
"""Raise ``MintPolicyViolation`` when the triple does not match the
|
||||
canonical ``MINTABLE_PROFILES`` entry for ``subject_type``.
|
||||
"""
|
||||
profile = MINTABLE_PROFILES.get(subject_type)
|
||||
if profile is None:
|
||||
raise MintPolicyViolation(f"mint_policy_violation: unknown subject_type={subject_type!r}")
|
||||
|
||||
drift = []
|
||||
if profile.prefix != prefix:
|
||||
drift.append(f"prefix got={prefix!r} expected={profile.prefix!r}")
|
||||
if frozenset(scopes) != profile.scopes:
|
||||
got = sorted(s.value for s in scopes)
|
||||
want = sorted(s.value for s in profile.scopes)
|
||||
drift.append(f"scopes got={got} expected={want}")
|
||||
|
||||
if drift:
|
||||
raise MintPolicyViolation(f"mint_policy_violation: subject_type={subject_type.value} — " + "; ".join(drift))
|
||||
@ -1,32 +0,0 @@
|
||||
"""Single-source visibility filter for the /openapi/v1/* surface.
|
||||
|
||||
Keep every openapi-surface app query routed through ``_apply_openapi_gate``;
|
||||
retiring or replacing the gate then becomes a one-line change here.
|
||||
|
||||
The Service API (/v1/* app-key surface) does NOT use this helper — that
|
||||
surface has its own per-request guard (``service_api_disabled``) wired
|
||||
into the legacy ``validate_app_token`` decorator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from models.model import App
|
||||
|
||||
|
||||
def apply_openapi_gate(query: Any) -> Any:
|
||||
"""Filter a SQLAlchemy Select/Query to apps visible on /openapi/v1/*.
|
||||
|
||||
Works with both legacy ``Query.filter`` and 2.0-style ``Select.filter``
|
||||
(alias of ``.where``).
|
||||
"""
|
||||
return query.filter(App.enable_api.is_(True))
|
||||
|
||||
|
||||
def is_openapi_visible(app: App) -> bool:
|
||||
"""Per-row counterpart for code paths that fetch an App by primary key
|
||||
(``session.get`` / ``session.scalar``) and need the same visibility check
|
||||
the query gate would have applied.
|
||||
"""
|
||||
return bool(app.enable_api)
|
||||
@ -15,7 +15,7 @@ from models import Account, AccountStatus
|
||||
from models.model import App, EndUser, Site
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import PERMISSION_CHECK_MODES, EnterpriseService, WebAppAccessMode
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||
|
||||
@ -137,8 +137,12 @@ class WebAppAuthService:
|
||||
"""
|
||||
Check if the app requires permission check based on its access mode.
|
||||
"""
|
||||
modes_requiring_permission_check = [
|
||||
"private",
|
||||
"private_all",
|
||||
]
|
||||
if access_mode:
|
||||
return access_mode in PERMISSION_CHECK_MODES
|
||||
return access_mode in modes_requiring_permission_check
|
||||
|
||||
if not app_code and not app_id:
|
||||
raise ValueError("Either app_code or app_id must be provided.")
|
||||
@ -149,7 +153,7 @@ class WebAppAuthService:
|
||||
raise ValueError("App ID could not be determined from the provided app_code.")
|
||||
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if webapp_settings and webapp_settings.access_mode in PERMISSION_CHECK_MODES:
|
||||
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -162,11 +166,11 @@ class WebAppAuthService:
|
||||
raise ValueError("Either app_code or access_mode must be provided.")
|
||||
|
||||
if access_mode:
|
||||
if access_mode == WebAppAccessMode.PUBLIC:
|
||||
if access_mode == "public":
|
||||
return WebAppAuthType.PUBLIC
|
||||
elif access_mode in PERMISSION_CHECK_MODES:
|
||||
elif access_mode in ["private", "private_all"]:
|
||||
return WebAppAuthType.INTERNAL
|
||||
elif access_mode == WebAppAccessMode.SSO_VERIFIED:
|
||||
elif access_mode == "sso_verified":
|
||||
return WebAppAuthType.EXTERNAL
|
||||
|
||||
if app_code:
|
||||
|
||||
@ -1,125 +0,0 @@
|
||||
"""Shared fixtures for /openapi/v1/* integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus
|
||||
|
||||
|
||||
def _sha256(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def disable_enterprise(monkeypatch):
|
||||
"""Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
|
||||
EE branch override this with their own monkeypatch in-test."""
|
||||
from configs import dify_config
|
||||
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
|
||||
with flask_app.app_context():
|
||||
tenant = Tenant(name="t1", status="normal")
|
||||
account = Account(email="u@example.com", name="u")
|
||||
db.session.add_all([tenant, account])
|
||||
db.session.commit()
|
||||
account.status = AccountStatus.ACTIVE
|
||||
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
yield account, tenant, join
|
||||
db.session.delete(join)
|
||||
db.session.delete(account)
|
||||
db.session.delete(tenant)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
|
||||
_, tenant, _ = workspace_account
|
||||
with flask_app.app_context():
|
||||
app = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
yield app
|
||||
db.session.delete(app)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mint_token(flask_app: Flask):
|
||||
"""Factory fixture; tracks minted rows and deletes them on teardown so
|
||||
the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
|
||||
minted: list[OAuthAccessToken] = []
|
||||
|
||||
def _mint(
|
||||
token: str,
|
||||
*,
|
||||
account_id: str | None,
|
||||
prefix: str,
|
||||
subject_email: str,
|
||||
subject_issuer: str | None,
|
||||
) -> OAuthAccessToken:
|
||||
with flask_app.app_context():
|
||||
row = OAuthAccessToken(
|
||||
token_hash=_sha256(token),
|
||||
prefix=prefix,
|
||||
account_id=account_id,
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
client_id="difyctl",
|
||||
device_label="test-device",
|
||||
expires_at=datetime.now(UTC) + timedelta(hours=1),
|
||||
)
|
||||
db.session.add(row)
|
||||
db.session.commit()
|
||||
minted.append(row)
|
||||
return row
|
||||
|
||||
yield _mint
|
||||
|
||||
with flask_app.app_context():
|
||||
for row in minted:
|
||||
db.session.delete(db.session.merge(row))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account_token(workspace_account, mint_token) -> str:
|
||||
account, _, _ = workspace_account
|
||||
token = "dfoa_" + uuid.uuid4().hex
|
||||
mint_token(
|
||||
token,
|
||||
account_id=account.id,
|
||||
prefix="dfoa_",
|
||||
subject_email=account.email,
|
||||
subject_issuer="dify:account",
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
|
||||
def _flush():
|
||||
with flask_app.app_context():
|
||||
for k in redis_client.keys("auth:*"):
|
||||
redis_client.delete(k)
|
||||
for k in redis_client.keys("rl:*"):
|
||||
redis_client.delete(k)
|
||||
|
||||
_flush()
|
||||
yield
|
||||
_flush()
|
||||
@ -1,238 +0,0 @@
|
||||
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models import App
|
||||
|
||||
|
||||
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
captured["mode"] = app_model.mode
|
||||
captured["args"] = args
|
||||
captured["invoke_from"] = invoke_from
|
||||
return {
|
||||
"event": "message",
|
||||
"task_id": "t",
|
||||
"id": "m",
|
||||
"message_id": "m",
|
||||
"conversation_id": "c",
|
||||
"mode": "chat",
|
||||
"answer": "ok",
|
||||
"created_at": 0,
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.get_json()["mode"] == "chat"
|
||||
assert captured["mode"] == "chat"
|
||||
assert captured["invoke_from"] == InvokeFrom.OPENAPI
|
||||
assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_mode(flask_app: Flask, workspace_account):
|
||||
"""Factory that creates an App row in the workspace_account tenant with
|
||||
a specified mode. Tracks rows for teardown.
|
||||
"""
|
||||
_, tenant, _ = workspace_account
|
||||
created: list[App] = []
|
||||
|
||||
def _make(mode: str) -> App:
|
||||
with flask_app.app_context():
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name=f"a-{mode}",
|
||||
mode=mode,
|
||||
status="normal",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db.session.refresh(app)
|
||||
db.session.expunge(app)
|
||||
created.append(app)
|
||||
return app
|
||||
|
||||
yield _make
|
||||
|
||||
with flask_app.app_context():
|
||||
for app in created:
|
||||
db.session.delete(db.session.merge(app))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"query_required_for_chat" in res.data
|
||||
|
||||
|
||||
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("completion")
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
captured["mode"] = app_model.mode
|
||||
captured["args"] = args
|
||||
return {
|
||||
"event": "message",
|
||||
"task_id": "t",
|
||||
"id": "m",
|
||||
"message_id": "m",
|
||||
"mode": "completion",
|
||||
"answer": "ok",
|
||||
"created_at": 0,
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.get_json()["mode"] == "completion"
|
||||
assert captured["mode"] == "completion"
|
||||
|
||||
|
||||
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("workflow")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"query_not_supported_for_workflow" in res.data
|
||||
|
||||
|
||||
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("workflow")
|
||||
|
||||
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
|
||||
return {
|
||||
"workflow_run_id": "wfr",
|
||||
"task_id": "t",
|
||||
"data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
|
||||
}
|
||||
|
||||
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.get_json()
|
||||
assert body["mode"] == "workflow"
|
||||
assert body["workflow_run_id"] == "wfr"
|
||||
|
||||
|
||||
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
|
||||
app = app_with_mode("channel")
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app.id}/run",
|
||||
json={"inputs": {}, "response_mode": "blocking"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
assert b"mode_not_runnable" in res.data
|
||||
|
||||
|
||||
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
)
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
"""Stub the authenticator to return an AuthContext with empty scopes."""
|
||||
from libs import oauth_bearer
|
||||
|
||||
real_authenticate = oauth_bearer.BearerAuthenticator.authenticate
|
||||
|
||||
def _stub_authenticate(self, token: str):
|
||||
ctx = real_authenticate(self, token)
|
||||
from dataclasses import replace
|
||||
|
||||
return replace(ctx, scopes=frozenset())
|
||||
|
||||
monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)
|
||||
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 403
|
||||
|
||||
|
||||
def test_run_with_unknown_app_returns_404(flask_app, account_token):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{uuid.uuid4()}/run",
|
||||
json={"inputs": {}, "query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
|
||||
def _stream() -> Generator[str, None, None]:
|
||||
yield 'event: message\ndata: {"x": 1}\n\n'
|
||||
|
||||
monkeypatch.setattr(
|
||||
"controllers.openapi.app_run.AppGenerateService.generate",
|
||||
staticmethod(lambda **kw: _stream()),
|
||||
)
|
||||
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.headers["Content-Type"].startswith("text/event-stream")
|
||||
assert b"event: message" in res.data
|
||||
|
||||
|
||||
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
|
||||
client = flask_app.test_client()
|
||||
res = client.post(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/run",
|
||||
json={"query": "hi"},
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
@ -1,210 +0,0 @@
|
||||
"""Integration tests for /openapi/v1/apps* read surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from models import App
|
||||
|
||||
|
||||
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/parameters",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_info_route_404(test_client, app_in_workspace, account_token):
|
||||
resp = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_apps_describe_returns_merged_shape(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"]["id"] == app_in_workspace.id
|
||||
assert body["info"]["mode"] == "chat"
|
||||
assert isinstance(body["parameters"], dict)
|
||||
|
||||
|
||||
def test_apps_describe_full_includes_input_schema(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is not None
|
||||
assert body["input_schema"] is not None
|
||||
assert body["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
|
||||
def test_apps_describe_fields_info_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_parameters_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is None
|
||||
assert body["parameters"] is not None
|
||||
assert body["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_input_schema_only(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_combined(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["info"] is not None
|
||||
assert body["parameters"] is None
|
||||
assert body["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_unknown_returns_422(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_apps_describe_fields_extra_param_returns_422(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
def test_apps_list_returns_pagination_envelope(
|
||||
test_client: FlaskClient,
|
||||
workspace_account,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
_, tenant, _ = workspace_account
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
assert body["page"] == 1
|
||||
assert body["limit"] == 20
|
||||
assert body["total"] >= 1
|
||||
assert any(d["id"] == app_in_workspace.id for d in body["data"])
|
||||
|
||||
|
||||
def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
|
||||
res = test_client.get("/openapi/v1/apps", headers={"Authorization": f"Bearer {account_token}"})
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_apps_list_tag_no_match_returns_empty_data_not_400(
|
||||
test_client: FlaskClient,
|
||||
workspace_account,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
):
|
||||
_, tenant, _ = workspace_account
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json["data"] == []
|
||||
|
||||
|
||||
def test_account_sessions_returns_envelope(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
):
|
||||
res = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
|
||||
assert res.status_code == 200
|
||||
body = res.json
|
||||
# canonical envelope shape
|
||||
assert isinstance(body["data"], list)
|
||||
assert "page" in body
|
||||
assert "limit" in body
|
||||
assert "total" in body
|
||||
assert "has_more" in body
|
||||
# the bearer's own minted session must appear
|
||||
assert any(s["prefix"] == "dfoa_" for s in body["data"])
|
||||
# legacy "sessions" key must NOT appear
|
||||
assert "sessions" not in body
|
||||
@ -1,127 +0,0 @@
|
||||
"""Integration tests for the /openapi/v1 bearer auth surface.
|
||||
|
||||
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
|
||||
acceptance/rejection on app-scoped routes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
def test_info_accepts_account_bearer_with_apps_read_scope(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{app_in_workspace.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json["id"] == app_in_workspace.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
|
||||
"""A fresh app under a *different* tenant — caller has no membership row."""
|
||||
with flask_app.app_context():
|
||||
other_tenant = Tenant(name="other", status="normal")
|
||||
db.session.add(other_tenant)
|
||||
db.session.commit()
|
||||
app = App(
|
||||
tenant_id=other_tenant.id,
|
||||
name="b",
|
||||
mode="chat",
|
||||
status="normal",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
yield app
|
||||
db.session.delete(app)
|
||||
db.session.delete(other_tenant)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def test_layer0_denies_account_bearer_without_membership(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
other_workspace_app: App,
|
||||
) -> None:
|
||||
"""Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 403
|
||||
assert res.json.get("message") == "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_layer0_skipped_when_enterprise_enabled(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
other_workspace_app: App,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.
|
||||
|
||||
/info uses validate_bearer + require_workspace_member inline (no
|
||||
AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
|
||||
gets 200 — gateway is expected to enforce isolation upstream.
|
||||
"""
|
||||
from configs import dify_config
|
||||
|
||||
# Override the conftest autouse default for this test only.
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)
|
||||
|
||||
res = test_client.get(
|
||||
f"/openapi/v1/apps/{other_workspace_app.id}/info",
|
||||
headers={"Authorization": f"Bearer {account_token}"},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json.get("message") != "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_rate_limit_returns_429_after_60_requests(
|
||||
test_client: FlaskClient,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
"""61st sequential GET to /account on the same bearer → 429 with Retry-After."""
|
||||
headers = {"Authorization": f"Bearer {account_token}"}
|
||||
for i in range(60):
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 200, f"unexpected fail at i={i}"
|
||||
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 429
|
||||
assert r.headers.get("Retry-After"), "Retry-After header missing"
|
||||
assert int(r.headers["Retry-After"]) >= 1
|
||||
body = r.json or {}
|
||||
assert body.get("error") == "rate_limited"
|
||||
assert isinstance(body.get("retry_after_ms"), int)
|
||||
assert body["retry_after_ms"] >= 1000
|
||||
|
||||
|
||||
def test_rate_limit_bucket_shared_across_surfaces(
|
||||
test_client: FlaskClient,
|
||||
app_in_workspace: App,
|
||||
account_token: str,
|
||||
) -> None:
|
||||
"""30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
|
||||
headers = {"Authorization": f"Bearer {account_token}"}
|
||||
for _ in range(30):
|
||||
assert test_client.get("/openapi/v1/account", headers=headers).status_code == 200
|
||||
for _ in range(30):
|
||||
assert test_client.get(f"/openapi/v1/apps/{app_in_workspace.id}/info", headers=headers).status_code == 200
|
||||
|
||||
r = test_client.get("/openapi/v1/account", headers=headers)
|
||||
assert r.status_code == 429
|
||||
@ -1,66 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def test_pipeline_is_composed():
|
||||
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
|
||||
|
||||
|
||||
def test_pipeline_step_order():
|
||||
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
|
||||
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
|
||||
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
|
||||
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
|
||||
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
|
||||
steps = OAUTH_BEARER_PIPELINE._steps
|
||||
assert isinstance(steps[0], BearerCheck)
|
||||
assert isinstance(steps[1], SurfaceCheck)
|
||||
assert isinstance(steps[2], ScopeCheck)
|
||||
assert isinstance(steps[3], AppResolver)
|
||||
assert isinstance(steps[4], WorkspaceMembershipCheck)
|
||||
assert isinstance(steps[5], AppAuthzCheck)
|
||||
assert isinstance(steps[6], CallerMount)
|
||||
|
||||
|
||||
def test_pipeline_surface_check_accepts_account_only():
|
||||
"""Current pipeline serves /apps/<id>/run — account surface only."""
|
||||
surface = OAUTH_BEARER_PIPELINE._steps[1]
|
||||
assert isinstance(surface, SurfaceCheck)
|
||||
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
|
||||
|
||||
|
||||
def test_caller_mount_has_both_mounters():
|
||||
cm = OAUTH_BEARER_PIPELINE._steps[6]
|
||||
kinds = {type(m) for m in cm._mounters}
|
||||
assert AccountMounter in kinds
|
||||
assert EndUserMounter in kinds
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_acl_when_enabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = True
|
||||
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
|
||||
def test_strategy_resolver_picks_membership_when_disabled(fs):
|
||||
fs.get_system_features.return_value.webapp_auth.enabled = False
|
||||
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
|
||||
@ -1,21 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
assert ctx.subject_type is None
|
||||
assert ctx.subject_email is None
|
||||
assert ctx.account_id is None
|
||||
assert ctx.scopes == frozenset()
|
||||
assert ctx.app is None
|
||||
assert ctx.tenant is None
|
||||
assert ctx.caller is None
|
||||
assert ctx.caller_kind is None
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
|
||||
ctx = Context(request=MagicMock(), required_scope="apps:run")
|
||||
ctx.scopes = frozenset({"full"})
|
||||
assert "full" in ctx.scopes
|
||||
@ -1,61 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
def test_run_invokes_each_step_in_order():
|
||||
calls = []
|
||||
|
||||
class S:
|
||||
def __init__(self, tag):
|
||||
self.tag = tag
|
||||
|
||||
def __call__(self, ctx):
|
||||
calls.append(self.tag)
|
||||
|
||||
Pipeline(S("a"), S("b"), S("c")).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_run_short_circuits_on_raise():
|
||||
calls = []
|
||||
|
||||
class Boom:
|
||||
def __call__(self, ctx):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Tail:
|
||||
def __call__(self, ctx):
|
||||
calls.append("ran")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
|
||||
seen = {}
|
||||
|
||||
class FakeStep:
|
||||
def __call__(self, ctx):
|
||||
ctx.app = "APP"
|
||||
ctx.caller = "CALLER"
|
||||
ctx.caller_kind = "account"
|
||||
|
||||
pipeline = Pipeline(FakeStep())
|
||||
|
||||
@pipeline.guard(scope="apps:run")
|
||||
def handler(app_model, caller, caller_kind):
|
||||
seen["app_model"] = app_model
|
||||
seen["caller"] = caller
|
||||
seen["caller_kind"] = caller_kind
|
||||
return "ok"
|
||||
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/x", method="POST"):
|
||||
assert handler() == "ok"
|
||||
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
|
||||
@ -1,64 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppResolver
|
||||
from models import TenantStatus
|
||||
|
||||
|
||||
def _ctx(view_args):
|
||||
req = MagicMock()
|
||||
req.view_args = view_args
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def _app(*, status="normal", enable_api=True):
|
||||
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
|
||||
|
||||
|
||||
def _tenant(*, status=TenantStatus.NORMAL):
|
||||
return SimpleNamespace(id="t1", status=status)
|
||||
|
||||
|
||||
def test_resolver_rejects_missing_path_param():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx({}))
|
||||
|
||||
|
||||
def test_resolver_rejects_none_view_args():
|
||||
with pytest.raises(BadRequest):
|
||||
AppResolver()(_ctx(None))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_404_when_app_missing(db):
|
||||
db.session.get.side_effect = [None]
|
||||
with pytest.raises(NotFound):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_disabled(db):
|
||||
db.session.get.side_effect = [_app(enable_api=False)]
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
assert "service_api_disabled" in str(exc.value.description)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_403_when_tenant_archived(db):
|
||||
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
|
||||
with pytest.raises(Forbidden):
|
||||
AppResolver()(_ctx({"app_id": "x"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.db")
|
||||
def test_resolver_populates_app_and_tenant(db):
|
||||
db.session.get.side_effect = [_app(), _tenant()]
|
||||
ctx = _ctx({"app_id": "x"})
|
||||
AppResolver()(ctx)
|
||||
assert ctx.app.id == "app1"
|
||||
assert ctx.tenant.id == "t1"
|
||||
@ -1,75 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import AppAuthzCheck
|
||||
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id="acc1"):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.subject_email = "alice@example.com"
|
||||
c.account_id = account_id
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_private_calls_inner_api(ent):
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
|
||||
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
|
||||
user_id="acc1",
|
||||
app_id="app1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("access_mode", "subject_type", "expected"),
|
||||
[
|
||||
("public", SubjectType.ACCOUNT, True),
|
||||
("public", SubjectType.EXTERNAL_SSO, True),
|
||||
("sso_verified", SubjectType.ACCOUNT, True),
|
||||
("sso_verified", SubjectType.EXTERNAL_SSO, True),
|
||||
("private_all", SubjectType.ACCOUNT, True),
|
||||
("private_all", SubjectType.EXTERNAL_SSO, False),
|
||||
("private", SubjectType.EXTERNAL_SSO, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.auth.strategies.EnterpriseService")
|
||||
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
|
||||
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
|
||||
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
|
||||
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
|
||||
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
|
||||
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._has_tenant_membership")
|
||||
def test_membership_strategy_uses_join_lookup(member):
|
||||
member.return_value = True
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
|
||||
member.assert_called_once_with("acc1", "t1")
|
||||
|
||||
|
||||
def test_membership_strategy_rejects_external_sso():
|
||||
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
|
||||
|
||||
|
||||
def test_app_authz_check_raises_when_strategy_denies():
|
||||
deny = SimpleNamespace(authorize=lambda c: False)
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
assert "subject_no_app_access" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_app_authz_check_passes_when_strategy_allows():
|
||||
allow = SimpleNamespace(authorize=lambda c: True)
|
||||
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -1,67 +0,0 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import BearerCheck
|
||||
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType
|
||||
|
||||
|
||||
def _ctx(headers):
|
||||
req = MagicMock()
|
||||
req.headers = headers
|
||||
return Context(request=req, required_scope="apps:run")
|
||||
|
||||
|
||||
def test_bearer_check_rejects_missing_header():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_populates_context_and_g_auth_ctx(get_auth):
|
||||
tok_id = uuid.uuid4()
|
||||
authn = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="a@x.com",
|
||||
subject_issuer=None,
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=tok_id,
|
||||
source="oauth-account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="hash-1",
|
||||
verified_tenants={},
|
||||
)
|
||||
get_auth.return_value.authenticate.return_value = authn
|
||||
|
||||
app = Flask(__name__)
|
||||
ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
|
||||
with app.test_request_context():
|
||||
BearerCheck()(ctx)
|
||||
|
||||
assert ctx.subject_type == SubjectType.ACCOUNT
|
||||
assert ctx.subject_email == "a@x.com"
|
||||
assert ctx.scopes == frozenset({Scope.FULL})
|
||||
assert ctx.source == "oauth-account"
|
||||
assert ctx.token_id == tok_id
|
||||
assert ctx.token_hash == "hash-1"
|
||||
# BearerCheck must also publish the same identity on `g.auth_ctx`
|
||||
# so the surface gate + downstream handlers don't see two
|
||||
# different identity sources between the decorator + pipeline paths.
|
||||
assert g.auth_ctx is authn
|
||||
assert g.auth_ctx.client_id == "difyctl"
|
||||
@ -1,157 +0,0 @@
|
||||
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(request=MagicMock(), required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
c.cached_verified_tenants = cached_verified_tenants
|
||||
c.token_hash = token_hash
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def step():
|
||||
return WorkspaceMembershipCheck()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id=str(uuid.uuid4()),
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
account_id=None,
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": True},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": False},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_record.assert_called_once_with("hash-1", "t1", True)
|
||||
@ -1,77 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import CallerMount
|
||||
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id=None, subject_email=None):
|
||||
c = Context(request=MagicMock(), required_scope="apps:run")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.subject_email = subject_email
|
||||
c.app = SimpleNamespace(id="app1")
|
||||
c.tenant = SimpleNamespace(id="t1")
|
||||
return c
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.db")
|
||||
def test_account_mounter(db, login):
|
||||
account = SimpleNamespace()
|
||||
db.session.get.return_value = account
|
||||
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
|
||||
AccountMounter().mount(ctx)
|
||||
assert ctx.caller is account
|
||||
assert ctx.caller.current_tenant is ctx.tenant
|
||||
assert ctx.caller_kind == "account"
|
||||
login.assert_called_once_with(account)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.strategies._login_as")
|
||||
@patch("controllers.openapi.auth.strategies.EndUserService")
|
||||
def test_end_user_mounter(svc, login):
|
||||
eu = SimpleNamespace()
|
||||
svc.get_or_create_end_user_by_type.return_value = eu
|
||||
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
|
||||
EndUserMounter().mount(ctx)
|
||||
svc.get_or_create_end_user_by_type.assert_called_once_with(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id="t1",
|
||||
app_id="app1",
|
||||
user_id="a@x.com",
|
||||
)
|
||||
assert ctx.caller is eu
|
||||
assert ctx.caller_kind == "end_user"
|
||||
|
||||
|
||||
def test_caller_mount_dispatches_by_subject_type():
|
||||
seen = {}
|
||||
|
||||
class Fake:
|
||||
def __init__(self, st, tag):
|
||||
self._st, self._tag = st, tag
|
||||
|
||||
def applies_to(self, st):
|
||||
return st == self._st
|
||||
|
||||
def mount(self, ctx):
|
||||
seen["who"] = self._tag
|
||||
|
||||
cm = CallerMount(
|
||||
Fake(SubjectType.ACCOUNT, "acct"),
|
||||
Fake(SubjectType.EXTERNAL_SSO, "sso"),
|
||||
)
|
||||
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
|
||||
assert seen == {"who": "sso"}
|
||||
|
||||
|
||||
def test_caller_mount_raises_when_none_applies():
|
||||
with pytest.raises(Unauthorized):
|
||||
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
|
||||
@ -1,27 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import ScopeCheck
|
||||
|
||||
|
||||
def _ctx(scopes, required):
|
||||
c = Context(request=MagicMock(), required_scope=required)
|
||||
c.scopes = frozenset(scopes)
|
||||
return c
|
||||
|
||||
|
||||
def test_scope_check_passes_on_full():
|
||||
ScopeCheck()(_ctx({"full"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_passes_on_explicit_match():
|
||||
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
|
||||
|
||||
|
||||
def test_scope_check_rejects_when_missing():
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
|
||||
assert "insufficient_scope" in str(exc.value.description)
|
||||
@ -1,181 +0,0 @@
|
||||
"""Surface gate tests.
|
||||
|
||||
The gate has two attachment forms — decorator (`accept_subjects`) and
|
||||
pipeline step (`SurfaceCheck`) — and both must:
|
||||
- 403 on mismatched subject type with a canonical-path hint
|
||||
- emit `openapi.wrong_surface_denied` once with the right payload
|
||||
- pass-through on match
|
||||
- raise RuntimeError (not 403) if g.auth_ctx is missing — that's a
|
||||
wiring bug, not a user-driven failure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
|
||||
def _account_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=uuid.uuid4(),
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def _sso_ctx() -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=datetime.now(UTC),
|
||||
token_hash="h2",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_surface — shared core
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_check_surface_passes_when_subject_in_accepted():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
|
||||
|
||||
|
||||
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/permitted-external-apps"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
assert "wrong_surface" in exc.value.description
|
||||
# canonical-path hint should point at the caller's surface,
|
||||
# not the surface they were rejected from
|
||||
assert "/openapi/v1/apps" in exc.value.description
|
||||
emit.assert_called_once()
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
|
||||
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
|
||||
assert kwargs["client_id"] == "difyctl"
|
||||
assert kwargs["token_id"] is not None
|
||||
|
||||
|
||||
def test_check_surface_rejects_sso_on_account_surface():
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
g.auth_ctx = _sso_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
kwargs = emit.call_args.kwargs
|
||||
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
|
||||
|
||||
|
||||
def test_check_surface_runtime_error_when_g_auth_ctx_missing():
|
||||
"""Missing g.auth_ctx means the bearer layer didn't run — wiring bug,
|
||||
not a user-driven failure. Surface as RuntimeError (loud) so a future
|
||||
refactor doesn't accidentally let a route skip authentication and
|
||||
return a 403 that looks identical to a legitimate wrong-surface deny.
|
||||
"""
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps"):
|
||||
with pytest.raises(RuntimeError):
|
||||
check_surface(frozenset({SubjectType.ACCOUNT}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @accept_subjects — decorator form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/account-only")
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def _account_only():
|
||||
return "ok"
|
||||
|
||||
@app.route("/external-only")
|
||||
@accept_subjects(SubjectType.EXTERNAL_SSO)
|
||||
def _external_only():
|
||||
return "ok"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_passes_on_match():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/account-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
# Re-route through the decorated function by reaching for view_function
|
||||
view = app.view_functions["_account_only"]
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_accept_subjects_decorator_403_on_miss():
|
||||
app = _make_app()
|
||||
with app.test_request_context("/external-only"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
view = app.view_functions["_external_only"]
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
|
||||
with pytest.raises(Forbidden):
|
||||
view()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SurfaceCheck — pipeline step form
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _pipeline_ctx() -> Context:
|
||||
req = MagicMock()
|
||||
req.path = "/openapi/v1/apps/<id>/run"
|
||||
return Context(request=req, required_scope=Scope.APPS_RUN)
|
||||
|
||||
|
||||
def test_surface_check_passes_on_match():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
step(_pipeline_ctx()) # no raise
|
||||
|
||||
|
||||
def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context("/openapi/v1/apps/x/run"):
|
||||
g.auth_ctx = _account_ctx()
|
||||
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
emit.assert_called_once()
|
||||
@ -1,32 +0,0 @@
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bypass_pipeline(monkeypatch):
|
||||
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
|
||||
|
||||
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
|
||||
pipeline at import time; mocking the module attribute does not undo
|
||||
that. Patching Pipeline.run on the class is the bypass that actually
|
||||
works.
|
||||
"""
|
||||
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
a = Flask(__name__)
|
||||
a.config["TESTING"] = True
|
||||
return a
|
||||
@ -1,140 +0,0 @@
|
||||
"""User-scoped identity + session endpoints under /openapi/v1/account."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.account import (
|
||||
AccountApi,
|
||||
AccountSessionByIdApi,
|
||||
AccountSessionsApi,
|
||||
AccountSessionsSelfApi,
|
||||
)
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_account_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account" in rules
|
||||
|
||||
|
||||
def test_account_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountApi
|
||||
|
||||
|
||||
def test_account_sessions_self_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/self" in rules
|
||||
|
||||
|
||||
def test_sessions_self_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsSelfApi
|
||||
|
||||
|
||||
def test_account_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_self_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_list_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions" in rules
|
||||
|
||||
|
||||
def test_sessions_list_dispatches_to_sessions_api(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsApi
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_session_by_id_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/<string:session_id>" in rules
|
||||
|
||||
|
||||
def test_session_by_id_dispatches_to_correct_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/<string:session_id>")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionByIdApi
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_subject_match_for_account_filters_by_account_id():
|
||||
"""Account subject scopes queries via account_id."""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
aid = _uuid.uuid4()
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=aid,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
# One predicate, on account_id
|
||||
assert len(clauses) == 1
|
||||
assert "account_id" in str(clauses[0])
|
||||
|
||||
|
||||
def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
"""External SSO subject scopes via (subject_email, subject_issuer)
|
||||
AND account_id IS NULL — so a same-email account row from a
|
||||
federated tenant cannot be revoked through an SSO bearer.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
client_id="difyctl",
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
assert len(clauses) == 3
|
||||
rendered = " ".join(str(c) for c in clauses)
|
||||
assert "subject_email" in rendered
|
||||
assert "subject_issuer" in rendered
|
||||
assert "account_id IS NULL" in rendered
|
||||
@ -1,48 +0,0 @@
|
||||
"""Unit tests for AppDescribeQuery (`?fields=` allow-list)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.apps import AppDescribeQuery
|
||||
|
||||
|
||||
def test_no_fields_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_empty_string_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": ""})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_single_field() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info"})
|
||||
assert q.fields == {"info"}
|
||||
|
||||
|
||||
def test_comma_list() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info,parameters"})
|
||||
assert q.fields == {"info", "parameters"}
|
||||
|
||||
|
||||
def test_whitespace_tolerant() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": " info , input_schema "})
|
||||
assert q.fields == {"info", "input_schema"}
|
||||
|
||||
|
||||
def test_unknown_member_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "garbage"})
|
||||
|
||||
|
||||
def test_unknown_among_known_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info,garbage"})
|
||||
|
||||
|
||||
def test_extra_param_forbidden() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info", "page": "1"})
|
||||
@ -1,105 +0,0 @@
|
||||
"""Unit tests for AppListQuery — the /apps query-param validator.
|
||||
|
||||
Runs against the model directly, not the HTTP layer. Pins:
|
||||
- defaults match the plan (page=1, limit=20).
|
||||
- workspace_id is required.
|
||||
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
|
||||
- mode validates against the AppMode enum.
|
||||
- name and tag have length caps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT
|
||||
from controllers.openapi.apps import AppListQuery
|
||||
|
||||
|
||||
def test_defaults():
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1"})
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 1
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
assert q.tag is None
|
||||
|
||||
|
||||
def test_workspace_id_required():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({})
|
||||
|
||||
|
||||
def test_page_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": -1})
|
||||
|
||||
|
||||
def test_page_rejects_non_integer_string():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": "abc"})
|
||||
|
||||
|
||||
def test_limit_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": -1})
|
||||
|
||||
|
||||
def test_limit_caps_at_max_page_limit():
|
||||
# Boundary accepts.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT})
|
||||
assert q.limit == MAX_PAGE_LIMIT
|
||||
|
||||
# Just over rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT + 1})
|
||||
|
||||
|
||||
def test_mode_whitelisted_against_app_mode():
|
||||
# Valid mode passes.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "chat"})
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "chat"
|
||||
|
||||
# Invalid mode rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "not-a-mode"})
|
||||
|
||||
|
||||
def test_name_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 200})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 201})
|
||||
|
||||
|
||||
def test_tag_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 100})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 101})
|
||||
|
||||
|
||||
def test_all_fields_accept_valid_values():
|
||||
"""Pin the happy-path acceptance for every field in one place."""
|
||||
q = AppListQuery.model_validate(
|
||||
{
|
||||
"workspace_id": "ws-1",
|
||||
"page": 5,
|
||||
"limit": 50,
|
||||
"mode": "workflow",
|
||||
"name": "search",
|
||||
"tag": "prod",
|
||||
}
|
||||
)
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 5
|
||||
assert q.limit == 50
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "workflow"
|
||||
assert q.name == "search"
|
||||
assert q.tag == "prod"
|
||||
@ -1,55 +0,0 @@
|
||||
"""Unit tests for app payload-rendering helpers — independent of
|
||||
HTTP plumbing or DB. Pin the response shapes that are CLI contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.apps import ( # pyright: ignore[reportPrivateUsage]
|
||||
_EMPTY_PARAMETERS,
|
||||
parameters_payload,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
|
||||
|
||||
def _fake_app(**overrides):
|
||||
base = {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author_name": "alice",
|
||||
"tags": [SimpleNamespace(name="prod")],
|
||||
"updated_at": None,
|
||||
"enable_api": True,
|
||||
"workflow": None,
|
||||
"app_model_config": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def test_parameters_payload_raises_app_unavailable_when_no_config():
|
||||
with pytest.raises(AppUnavailableError):
|
||||
parameters_payload(_fake_app(mode="chat", app_model_config=None))
|
||||
|
||||
|
||||
def test_empty_parameters_constant_matches_describe_fallback_shape():
|
||||
"""The fallback dict served by /describe when an app has no config
|
||||
must match the spec's stated keys (opening_statement, suggested_questions,
|
||||
user_input_form, file_upload, system_parameters)."""
|
||||
assert set(_EMPTY_PARAMETERS.keys()) == {
|
||||
"opening_statement",
|
||||
"suggested_questions",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"system_parameters",
|
||||
}
|
||||
assert _EMPTY_PARAMETERS["suggested_questions"] == []
|
||||
assert _EMPTY_PARAMETERS["user_input_form"] == []
|
||||
assert _EMPTY_PARAMETERS["opening_statement"] is None
|
||||
assert _EMPTY_PARAMETERS["file_upload"] is None
|
||||
assert _EMPTY_PARAMETERS["system_parameters"] == {}
|
||||
@ -1,32 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.app_run import (
|
||||
_DISPATCH,
|
||||
AppRunRequest,
|
||||
)
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_dispatch_covers_runnable_modes():
|
||||
runnable = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW}
|
||||
assert set(_DISPATCH) == runnable
|
||||
|
||||
|
||||
def test_app_run_request_strips_blank_conversation_id():
|
||||
payload = AppRunRequest(inputs={}, conversation_id=" ")
|
||||
assert payload.conversation_id is None
|
||||
|
||||
|
||||
def test_app_run_request_rejects_invalid_uuid_conversation_id():
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="conversation_id must be a valid UUID"):
|
||||
AppRunRequest(inputs={}, conversation_id="not-a-uuid")
|
||||
|
||||
|
||||
def test_app_run_request_accepts_valid_uuid_conversation_id():
|
||||
import uuid as _uuid
|
||||
|
||||
cid = str(_uuid.uuid4())
|
||||
payload = AppRunRequest(inputs={}, conversation_id=cid)
|
||||
assert payload.conversation_id == cid
|
||||
@ -1,85 +0,0 @@
|
||||
"""Tests: openapi /run always streams; response_mode removed from AppRunRequest."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
|
||||
|
||||
def test_app_run_request_has_no_response_mode_field():
|
||||
"""response_mode must not be a declared field."""
|
||||
assert "response_mode" not in AppRunRequest.model_fields
|
||||
|
||||
|
||||
def test_app_run_request_ignores_response_mode_in_payload():
|
||||
"""Sending response_mode in JSON body is silently ignored (Pydantic extra='ignore')."""
|
||||
req = AppRunRequest.model_validate({"inputs": {}, "response_mode": "blocking"})
|
||||
assert not hasattr(req, "response_mode")
|
||||
|
||||
|
||||
def test_app_run_request_valid_minimal():
|
||||
req = AppRunRequest.model_validate({"inputs": {}})
|
||||
assert req.inputs == {}
|
||||
|
||||
|
||||
def test_app_run_request_with_query():
|
||||
req = AppRunRequest.model_validate({"inputs": {}, "query": "hello"})
|
||||
assert req.query == "hello"
|
||||
|
||||
|
||||
def test_run_chat_always_calls_generate_with_streaming_true(app, bypass_pipeline, monkeypatch):
|
||||
"""_run_chat must always invoke AppGenerateService.generate with streaming=True."""
|
||||
from controllers.openapi.app_run import _run_chat
|
||||
|
||||
generate_mock = Mock(return_value=iter([]))
|
||||
monkeypatch.setattr(
|
||||
sys.modules["controllers.openapi.app_run"],
|
||||
"AppGenerateService",
|
||||
SimpleNamespace(generate=generate_mock),
|
||||
)
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/run", method="POST"):
|
||||
_run_chat(
|
||||
SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
SimpleNamespace(id="acct-1"),
|
||||
AppRunRequest(inputs={}, query="hello"),
|
||||
)
|
||||
_, kwargs = generate_mock.call_args
|
||||
assert kwargs["streaming"] is True
|
||||
|
||||
|
||||
def test_stop_task_endpoint_registered(openapi_app):
|
||||
"""POST /openapi/v1/apps/<id>/tasks/<task_id>/stop must be registered."""
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/apps/<string:app_id>/tasks/<string:task_id>/stop" in rules
|
||||
|
||||
|
||||
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.app_run import AppRunTaskStopApi
|
||||
|
||||
queue_mock = Mock()
|
||||
graph_mock = Mock()
|
||||
graph_instance = Mock()
|
||||
graph_mock.return_value = graph_instance
|
||||
|
||||
run_module = sys.modules["controllers.openapi.app_run"]
|
||||
monkeypatch.setattr(run_module, "AppQueueManager", queue_mock)
|
||||
monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock)
|
||||
monkeypatch.setattr(run_module, "redis_client", object())
|
||||
|
||||
api = AppRunTaskStopApi()
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
task_id="task-1",
|
||||
app_model=SimpleNamespace(id="app-1", tenant_id="t-1"),
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")
|
||||
graph_instance.send_stop_command.assert_called_once_with("task-1")
|
||||
assert result == {"result": "success"}
|
||||
@ -1,53 +0,0 @@
|
||||
"""Unit tests for PermittedExternalAppsListQuery — the
|
||||
/permitted-external-apps query validator.
|
||||
|
||||
Strict ConfigDict(extra='forbid'): cross-tenant tag/workspace_id are
|
||||
unresolvable, so the model must reject them as 422 instead of silently
|
||||
dropping them. Mode/name/page/limit have the same shape as AppListQuery.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.apps_permitted_external import PermittedExternalAppsListQuery
|
||||
|
||||
|
||||
def test_query_defaults_match_apps_list():
|
||||
q = PermittedExternalAppsListQuery.model_validate({})
|
||||
assert q.page == 1
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
|
||||
|
||||
def test_query_rejects_workspace_id():
|
||||
"""workspace_id is meaningless for /permitted-external-apps (cross-tenant);
|
||||
rejecting it forces CLI authors to drop the param rather than send it
|
||||
silently."""
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"workspace_id": "ws-1"})
|
||||
|
||||
|
||||
def test_query_rejects_tag():
|
||||
"""Tags are tenant-scoped; cross-tenant tag resolution is undefined."""
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"tag": "prod"})
|
||||
|
||||
|
||||
def test_query_validates_mode_against_app_mode():
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"mode": "not-a-mode"})
|
||||
|
||||
|
||||
def test_query_clamps_limit_at_max():
|
||||
with pytest.raises(ValidationError):
|
||||
PermittedExternalAppsListQuery.model_validate({"limit": 500})
|
||||
|
||||
|
||||
def test_query_accepts_valid_mode():
|
||||
"""Pin the happy path: AppMode values pass."""
|
||||
q = PermittedExternalAppsListQuery.model_validate({"mode": "chat"})
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "chat"
|
||||
@ -1,26 +0,0 @@
|
||||
import logging
|
||||
|
||||
from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run
|
||||
|
||||
|
||||
def test_event_constant():
|
||||
assert EVENT_APP_RUN_OPENAPI == "app.run.openapi"
|
||||
|
||||
|
||||
def test_emit_app_run_logs_with_audit_extra(caplog):
|
||||
with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"):
|
||||
emit_app_run(
|
||||
app_id="app1",
|
||||
tenant_id="t1",
|
||||
caller_kind="account",
|
||||
mode="chat",
|
||||
surface="apps",
|
||||
)
|
||||
record = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message)
|
||||
assert record.audit is True
|
||||
assert record.event == EVENT_APP_RUN_OPENAPI
|
||||
assert record.app_id == "app1"
|
||||
assert record.tenant_id == "t1"
|
||||
assert record.caller_kind == "account"
|
||||
assert record.mode == "chat"
|
||||
assert record.surface == "apps"
|
||||
@ -1,127 +0,0 @@
|
||||
"""CORS posture for /openapi/v1/* — default empty allowlist (same-origin),
|
||||
expandable via OPENAPI_CORS_ALLOW_ORIGINS. Cross-origin requests from
|
||||
disallowed origins do not receive the Access-Control-Allow-Origin
|
||||
header, which the browser then blocks.
|
||||
|
||||
Tests use a fresh Blueprint + Flask-CORS per case because the production
|
||||
blueprint is a module-level singleton and can't be reconfigured once
|
||||
registered.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
from flask import Blueprint, Flask
|
||||
from flask.views import MethodView
|
||||
from flask_cors import CORS
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_blueprints import OPENAPI_HEADERS, OPENAPI_MAX_AGE_SECONDS
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _make_app(allowed_origins: list[str], blueprint_name: str) -> Flask:
|
||||
"""Build a Flask app with a fresh openapi-style blueprint mirroring
|
||||
production CORS settings, parameterised on the origin allowlist.
|
||||
"""
|
||||
bp = Blueprint(blueprint_name, __name__, url_prefix="/openapi/v1")
|
||||
api = ExternalApi(bp, version="1.0", title="OpenAPI Test", description="")
|
||||
|
||||
@api.route("/_health")
|
||||
class _Health(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
|
||||
CORS(
|
||||
bp,
|
||||
resources={r"/*": {"origins": allowed_origins}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=["X-Version"],
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_default_openapi_cors_allowlist_is_empty():
|
||||
"""Default config admits no cross-origin until operator opts in."""
|
||||
assert dify_config.OPENAPI_CORS_ALLOW_ORIGINS == []
|
||||
|
||||
|
||||
def test_preflight_allowed_origin_returns_cors_headers():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t1")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.headers.get("Access-Control-Allow-Origin") == "https://app.example.com"
|
||||
assert response.headers.get("Access-Control-Max-Age") == str(OPENAPI_MAX_AGE_SECONDS)
|
||||
|
||||
|
||||
def test_preflight_disallowed_origin_omits_cors_headers():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t2")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://attacker.example",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
# flask-cors omits Allow-Origin for disallowed origins; browser blocks.
|
||||
assert "Access-Control-Allow-Origin" not in response.headers
|
||||
|
||||
|
||||
def test_preflight_with_default_empty_allowlist_omits_cors_headers():
|
||||
app = _make_app([], "openapi_t3")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
assert "Access-Control-Allow-Origin" not in response.headers
|
||||
|
||||
|
||||
def test_same_origin_request_succeeds_without_origin_header():
|
||||
app = _make_app(["https://app.example.com"], "openapi_t4")
|
||||
client = app.test_client()
|
||||
# Browsers don't send Origin on same-origin GETs.
|
||||
response = client.get("/openapi/v1/_health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"ok": True}
|
||||
|
||||
|
||||
def test_authorization_header_is_in_allow_headers():
|
||||
"""Bearer-authed routes need Authorization in the preflight response."""
|
||||
app = _make_app(["https://app.example.com"], "openapi_t5")
|
||||
client = app.test_client()
|
||||
response = client.options(
|
||||
"/openapi/v1/_health",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
"Access-Control-Request-Headers": "Authorization",
|
||||
},
|
||||
)
|
||||
|
||||
allow_headers = response.headers.get("Access-Control-Allow-Headers", "").lower()
|
||||
assert "authorization" in allow_headers
|
||||
@ -1,52 +0,0 @@
|
||||
"""Account-branch device-flow approve/deny under /openapi/v1."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import DeviceApproveApi, DeviceDenyApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_approve_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approve" in rules
|
||||
|
||||
|
||||
def test_deny_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/deny" in rules
|
||||
|
||||
|
||||
def test_approve_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceApproveApi
|
||||
|
||||
|
||||
def test_deny_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is DeviceDenyApi
|
||||
|
||||
|
||||
def test_approve_and_deny_methods(openapi_app: Flask):
|
||||
approve = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
|
||||
deny = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
|
||||
assert "POST" in approve.methods
|
||||
assert "POST" in deny.methods
|
||||
@ -1,47 +0,0 @@
|
||||
"""POST /openapi/v1/oauth/device/code is the canonical RFC 8628 device
|
||||
authorization endpoint.
|
||||
|
||||
Tests verify URL routing without invoking the handler — invoking would
|
||||
require Redis, which the unit-test runtime does not initialise.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceCodeApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/code" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi
|
||||
|
||||
|
||||
def test_route_accepts_post(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/code")
|
||||
assert "POST" in rule.methods
|
||||
|
||||
|
||||
def test_known_client_ids_default_includes_difyctl():
|
||||
from configs import dify_config
|
||||
|
||||
assert "difyctl" in dify_config.OPENAPI_KNOWN_CLIENT_IDS
|
||||
@ -1,36 +0,0 @@
|
||||
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceLookupApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/lookup" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi
|
||||
|
||||
|
||||
def test_route_accepts_get(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/lookup")
|
||||
assert "GET" in rule.methods
|
||||
@ -1,105 +0,0 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""
|
||||
|
||||
import builtins
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device_sso import (
|
||||
_email_belongs_to_dify_account,
|
||||
approval_context,
|
||||
approve_external,
|
||||
sso_complete,
|
||||
sso_initiate,
|
||||
)
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_sso_initiate_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/sso-initiate" in rules
|
||||
|
||||
|
||||
def test_sso_complete_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/sso-complete" in rules
|
||||
|
||||
|
||||
def test_approval_context_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approval-context" in rules
|
||||
|
||||
|
||||
def test_approve_external_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/approve-external" in rules
|
||||
|
||||
|
||||
def test_sso_initiate_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-initiate")
|
||||
assert openapi_app.view_functions[rule.endpoint] is sso_initiate
|
||||
|
||||
|
||||
def test_sso_complete_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-complete")
|
||||
assert openapi_app.view_functions[rule.endpoint] is sso_complete
|
||||
|
||||
|
||||
def test_approval_context_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approval-context")
|
||||
assert openapi_app.view_functions[rule.endpoint] is approval_context
|
||||
|
||||
|
||||
def test_approve_external_dispatches_to_function(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve-external")
|
||||
assert openapi_app.view_functions[rule.endpoint] is approve_external
|
||||
|
||||
|
||||
def test_sso_complete_idp_callback_url_uses_canonical_path():
|
||||
"""sso_initiate hardcodes the IdP callback URL — must point at the
|
||||
canonical /openapi/v1/ path so IdP-side ACS configuration matches.
|
||||
"""
|
||||
from controllers.openapi import oauth_device_sso
|
||||
|
||||
assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("email", "row", "expected"),
|
||||
[
|
||||
("alice@example.com", "acc1", True),
|
||||
("alice@example.com", None, False),
|
||||
("Alice@Example.COM", "acc1", True), # case-insensitive lookup
|
||||
(" alice@example.com ", "acc1", True), # surrounding whitespace stripped
|
||||
("", "acc1", False),
|
||||
(" ", "acc1", False),
|
||||
("", None, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.openapi.oauth_device_sso.db")
|
||||
def test_email_belongs_to_dify_account(db_mock, email, row, expected):
|
||||
exec_result = MagicMock()
|
||||
exec_result.scalar_one_or_none.return_value = row
|
||||
db_mock.session.execute.return_value = exec_result
|
||||
assert _email_belongs_to_dify_account(email) is expected
|
||||
if email.strip():
|
||||
db_mock.session.execute.assert_called_once()
|
||||
else:
|
||||
db_mock.session.execute.assert_not_called()
|
||||
@ -1,31 +0,0 @@
|
||||
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.oauth_device import OAuthDeviceTokenApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_openapi_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/oauth/device/token" in rules
|
||||
|
||||
|
||||
def test_route_dispatches_to_class(openapi_app: Flask):
|
||||
rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == "/openapi/v1/oauth/device/token")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi
|
||||
@ -1,33 +0,0 @@
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def test_health_returns_ok(app: Flask):
|
||||
client = app.test_client()
|
||||
response = client.get("/openapi/v1/_health")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"ok": True}
|
||||
|
||||
|
||||
def test_health_path_is_under_openapi_v1_prefix(app: Flask):
|
||||
client = app.test_client()
|
||||
assert client.get("/_health").status_code == 404
|
||||
assert client.get("/v1/_health").status_code == 404
|
||||
assert client.get("/openapi/v1/_health").status_code == 200
|
||||
@ -1,227 +0,0 @@
|
||||
"""Tests for openapi human input form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
class TestOpenApiHumanInputFormGet:
|
||||
def test_get_success(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
definition = SimpleNamespace(
|
||||
model_dump=lambda: {
|
||||
"rendered_content": "Fill out the form",
|
||||
"inputs": [{"output_variable_name": "field1"}],
|
||||
"default_values": {"field1": "default"},
|
||||
"user_actions": [{"id": "submit", "title": "Submit"}],
|
||||
}
|
||||
)
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
get_definition=lambda: definition,
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
service_mock.ensure_form_active = Mock()
|
||||
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
resp = api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
payload = json.loads(resp.get_data(as_text=True))
|
||||
assert payload["form_content"] == "Fill out the form"
|
||||
assert payload["resolved_default_values"] == {"field1": "default"}
|
||||
assert payload["user_actions"] == [{"id": "submit", "title": "Submit"}]
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
|
||||
def test_get_form_not_found(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = None
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
|
||||
with pytest.raises(NotFound):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="bad",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = SimpleNamespace(
|
||||
app_id="other-app", tenant_id="tenant-1", expiration_time=datetime(2099, 1, 1, tzinfo=UTC)
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
|
||||
with pytest.raises(NotFound):
|
||||
api.get.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=SimpleNamespace(id="acct-1"),
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
|
||||
class TestOpenApiHumanInputFormPost:
|
||||
def _make_form(self, app_id="app-1", recipient_type=RecipientType.STANDALONE_WEB_APP):
|
||||
return SimpleNamespace(
|
||||
app_id=app_id,
|
||||
tenant_id="tenant-1",
|
||||
recipient_type=recipient_type,
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
def test_post_account_caller_uses_user_id(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = self._make_form()
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="acct-42")
|
||||
|
||||
with app.test_request_context(
|
||||
"/openapi/v1/apps/app-1/form/human_input/tok-1",
|
||||
method="POST",
|
||||
json={"action": "approve", "inputs": {"field1": "val"}},
|
||||
):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="account",
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="tok-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"field1": "val"},
|
||||
submission_user_id="acct-42",
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
assert result == ({}, 200)
|
||||
|
||||
def test_post_end_user_caller_uses_end_user_id(self, app, bypass_pipeline, monkeypatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = self._make_form()
|
||||
service_mock = Mock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
|
||||
module = sys.modules["controllers.openapi.human_input_form"]
|
||||
monkeypatch.setattr(module, "HumanInputService", lambda _engine: service_mock)
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = OpenApiWorkflowHumanInputFormApi()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
caller = SimpleNamespace(id="eu-7")
|
||||
|
||||
with app.test_request_context(
|
||||
"/openapi/v1/apps/app-1/form/human_input/tok-1",
|
||||
method="POST",
|
||||
json={"action": "approve", "inputs": {}},
|
||||
):
|
||||
result = api.post.__wrapped__(
|
||||
api,
|
||||
app_id="app-1",
|
||||
form_token="tok-1",
|
||||
app_model=app_model,
|
||||
caller=caller,
|
||||
caller_kind="end_user",
|
||||
)
|
||||
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="tok-1",
|
||||
selected_action_id="approve",
|
||||
form_data={},
|
||||
submission_user_id=None,
|
||||
submission_end_user_id="eu-7",
|
||||
)
|
||||
assert result == ({}, 200)
|
||||
@ -1,182 +0,0 @@
|
||||
"""Unit tests for input_schema derivation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi._input_schema import _form_to_jsonschema
|
||||
|
||||
|
||||
def _wrap(component: dict) -> list[dict]:
|
||||
"""user_input_form rows are single-key dicts: {"text-input": {...}}."""
|
||||
return [component]
|
||||
|
||||
|
||||
def test_text_input_required() -> None:
|
||||
form = _wrap({"text-input": {"variable": "industry", "label": "Industry", "required": True, "max_length": 200}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"industry": {"type": "string", "title": "Industry", "maxLength": 200}}
|
||||
assert required == ["industry"]
|
||||
|
||||
|
||||
def test_paragraph_optional() -> None:
|
||||
form = _wrap({"paragraph": {"variable": "context", "label": "Context", "required": False, "max_length": 4000}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["context"] == {"type": "string", "title": "Context", "maxLength": 4000}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_select_enum() -> None:
|
||||
form = _wrap(
|
||||
{
|
||||
"select": {
|
||||
"variable": "tier",
|
||||
"label": "Tier",
|
||||
"required": True,
|
||||
"options": ["free", "pro", "enterprise"],
|
||||
}
|
||||
}
|
||||
)
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"tier": {"type": "string", "title": "Tier", "enum": ["free", "pro", "enterprise"]}}
|
||||
assert required == ["tier"]
|
||||
|
||||
|
||||
def test_number() -> None:
|
||||
form = _wrap({"number": {"variable": "count", "label": "Count", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["count"] == {"type": "number", "title": "Count"}
|
||||
|
||||
|
||||
def test_file() -> None:
|
||||
form = _wrap({"file": {"variable": "doc", "label": "Doc", "required": True}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["doc"]["type"] == "object"
|
||||
assert "title" in props["doc"]
|
||||
assert required == ["doc"]
|
||||
|
||||
|
||||
def test_file_list() -> None:
|
||||
form = _wrap({"file-list": {"variable": "attachments", "label": "Attachments", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["attachments"]["type"] == "array"
|
||||
assert props["attachments"]["items"]["type"] == "object"
|
||||
|
||||
|
||||
def test_unknown_type_skipped() -> None:
|
||||
"""Forward-compat: unknown variable types are skipped, not 500'd."""
|
||||
form = _wrap({"future-type": {"variable": "x", "label": "X", "required": False}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_required_order_preserved() -> None:
|
||||
form = [
|
||||
{"text-input": {"variable": "a", "label": "A", "required": True}},
|
||||
{"text-input": {"variable": "b", "label": "B", "required": False}},
|
||||
{"text-input": {"variable": "c", "label": "C", "required": True}},
|
||||
]
|
||||
_props, required = _form_to_jsonschema(form)
|
||||
assert required == ["a", "c"]
|
||||
|
||||
|
||||
def test_max_length_omitted_when_zero() -> None:
|
||||
form = _wrap({"text-input": {"variable": "x", "label": "X", "required": False, "max_length": 0}})
|
||||
props, _ = _form_to_jsonschema(form)
|
||||
assert "maxLength" not in props["x"]
|
||||
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _stub_app(mode: AppMode, *, form: list[dict] | None = None, has_workflow: bool | None = None):
|
||||
"""Returns a MagicMock whose .mode + workflow / app_model_config branch is wired up."""
|
||||
app = MagicMock()
|
||||
app.mode = mode
|
||||
if mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
|
||||
if has_workflow is False:
|
||||
app.workflow = None
|
||||
else:
|
||||
app.workflow = MagicMock()
|
||||
app.workflow.user_input_form.return_value = form or []
|
||||
app.workflow.features_dict = {}
|
||||
else:
|
||||
if has_workflow is False:
|
||||
app.app_model_config = None
|
||||
else:
|
||||
app.app_model_config = MagicMock()
|
||||
app.app_model_config.to_dict.return_value = {"user_input_form": form or []}
|
||||
return app
|
||||
|
||||
|
||||
def test_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.CHAT, form=[{"text-input": {"variable": "x", "label": "X", "required": True}}])
|
||||
schema = build_input_schema(app)
|
||||
assert schema["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
assert "query" in schema["properties"]
|
||||
assert schema["properties"]["query"]["type"] == "string"
|
||||
assert schema["properties"]["query"]["minLength"] == 1
|
||||
assert "query" in schema["required"]
|
||||
assert "inputs" in schema["required"]
|
||||
assert schema["properties"]["inputs"]["additionalProperties"] is False
|
||||
|
||||
|
||||
def test_agent_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.AGENT_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_advanced_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.ADVANCED_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_workflow_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_completion_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.COMPLETION, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_inputs_required_driven_by_form() -> None:
|
||||
app = _stub_app(
|
||||
AppMode.CHAT,
|
||||
form=[
|
||||
{"text-input": {"variable": "industry", "label": "Industry", "required": True}},
|
||||
{"text-input": {"variable": "context", "label": "Context", "required": False}},
|
||||
],
|
||||
)
|
||||
schema = build_input_schema(app)
|
||||
assert schema["properties"]["inputs"]["required"] == ["industry"]
|
||||
|
||||
|
||||
def test_misconfigured_chat_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.CHAT, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_misconfigured_workflow_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_empty_input_schema_sentinel_shape() -> None:
|
||||
assert EMPTY_INPUT_SCHEMA["type"] == "object"
|
||||
assert EMPTY_INPUT_SCHEMA["properties"] == {}
|
||||
assert EMPTY_INPUT_SCHEMA["required"] == []
|
||||
@ -1,31 +0,0 @@
|
||||
from controllers.openapi._models import MessageMetadata, UsageInfo
|
||||
|
||||
|
||||
def test_usage_info_defaults_zero():
|
||||
u = UsageInfo()
|
||||
assert u.prompt_tokens == 0
|
||||
assert u.completion_tokens == 0
|
||||
assert u.total_tokens == 0
|
||||
|
||||
|
||||
def test_message_metadata_accepts_partial():
|
||||
m = MessageMetadata(usage=UsageInfo(total_tokens=10))
|
||||
assert m.usage.total_tokens == 10
|
||||
assert m.retriever_resources == []
|
||||
|
||||
|
||||
def test_describe_response_all_blocks_optional() -> None:
|
||||
from controllers.openapi._models import AppDescribeResponse
|
||||
|
||||
payload = AppDescribeResponse().model_dump(mode="json", exclude_none=False)
|
||||
assert payload == {"info": None, "parameters": None, "input_schema": None}
|
||||
|
||||
|
||||
def test_describe_response_input_schema_field() -> None:
|
||||
from controllers.openapi._models import AppDescribeResponse
|
||||
|
||||
schema = {"$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object"}
|
||||
payload = AppDescribeResponse(input_schema=schema).model_dump(mode="json", exclude_none=False)
|
||||
assert payload["input_schema"] == schema
|
||||
assert payload["info"] is None
|
||||
assert payload["parameters"] is None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user