Compare commits

..

76 Commits

Author SHA1 Message Date
8d7ee1d761 Merge branch 'main' into feat/cli 2026-05-23 14:37:18 +08:00
a831920803 fix typing 2026-05-23 14:21:04 +08:00
98de360447 refactor: move db query from api leyer to service layer 2026-05-23 14:21:04 +08:00
95816a26b8 refactor(workspace): move data access to service 2026-05-23 14:21:04 +08:00
f39e7d6cd5 refactor: move select to data access layer 2026-05-23 14:21:04 +08:00
a6970bc144 [autofix.ci] apply automated fixes 2026-05-23 03:14:43 +00:00
fecdef6c21 fix test 2026-05-23 11:10:09 +08:00
b7a2347291 fmt 2026-05-23 10:33:07 +08:00
0c1b37687f refactor: decouple Context from flask 2026-05-23 10:33:07 +08:00
341a82bf1e [autofix.ci] apply automated fixes 2026-05-23 01:37:21 +00:00
e71df18d72 add test for openapi registration 2026-05-23 09:28:31 +08:00
152f916768 remove unused export 2026-05-22 22:58:54 +08:00
9b3b408849 fix web typecheck 2026-05-22 22:42:32 +08:00
102643e060 Merge branch 'main' into feat/cli 2026-05-22 22:25:00 +08:00
4c2ba50dfe Merge remote-tracking branch 'upstream/main' into feat/cli 2026-05-22 21:38:55 +08:00
3df1042706 fix api test 2026-05-22 21:38:17 +08:00
0f39ac8960 fix web lint 2026-05-22 18:24:24 +08:00
102a9f3eb3 [autofix.ci] apply automated fixes 2026-05-22 10:20:08 +00:00
d94e302045 fix typings 2026-05-22 18:15:28 +08:00
2ff07b6311 mysql support in migration 2026-05-22 17:38:03 +08:00
1554d80df5 [autofix.ci] apply automated fixes 2026-05-22 09:20:07 +00:00
7ec50f4656 fix typecheck && migration 2026-05-22 17:15:27 +08:00
66c4b9d589 remove docker build for current stage 2026-05-22 16:24:48 +08:00
cb218f2832 add permission constraint to smoke test 2026-05-22 16:17:55 +08:00
ed6a079582 [autofix.ci] apply automated fixes 2026-05-22 08:16:10 +00:00
f1d68e4178 use softprops/action-gh-release for release 2026-05-22 16:11:33 +08:00
851bf36f24 Merge remote-tracking branch 'upstream/main' into feat/cli 2026-05-22 16:07:59 +08:00
f6e4d558a6 Revert "test(cli): add integration test suite for Discovery, Run, Output, Error Handling and CLI Framework"
This reverts commit c38c5d375e.
2026-05-22 15:53:38 +08:00
c38c5d375e test(cli): add integration test suite for Discovery, Run, Output, Error Handling and CLI Framework
Add comprehensive integration tests under cli/test/testcases/ covering:

Discovery:
- App list (list, single, all-workspaces)
- Describe App
- Cross-workspace query

Run:
- Basic App execution
- Streaming output
- HITL (Human-in-the-Loop) — all 19 cases incl. multi-action / expired-token / already-consumed
- File input
- Conversation mode
- Environment variable injection
- Cache and version consistency

Output:
- JSON/YAML output
- Table output

Error Handling:
- Exit code end-to-end validation
- Error message spec

CLI Framework:
- Global Flags
- Non-Interactive mode

Also extend test fixtures:
- scenarios.ts: add hitl-pause-multi-action / hitl-resume-expired-token / hitl-resume-already-consumed
- server.ts: add GET /form/human_input route, multi-action HITL response, expired/consumed token error handling

Known bugs tracked as it.todo:
- WTA-249: server 4xx in -o json mode exit code should be 1 (currently 0 in some cases)
- WTA-252: --help missing GLOBAL FLAGS section and Quick start examples
- WTA-255: hosts.yml YAML parse failure should output JSON envelope
- WTA-257: uncaught TypeError should output JSON envelope in -o json mode
2026-05-22 10:46:18 +08:00
5381452de9 feat(cli,api): difyctl version probes server and reports compat verdict (#36356)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 10:27:34 +08:00
ed5f6b153f fix(cli): color hint lines in BaseError human output
Apply magenta prefix + cyan value to all hint lines rendered via
formatErrorForCli, covering logout, authed-command guard, and any
other BaseError path. Thread isErrTTY from process.stderr.isTTY.
2026-05-19 15:35:11 -07:00
6f760a3901 fix(api): restore conversation history for OPENAPI invoke source
OPENAPI callers (difyctl) don't send parent_message_id, causing
extract_thread_messages to treat each message as a new thread start
and return empty history. Apply UUID_NIL sentinel for OPENAPI the
same way SERVICE_API does, keeping the two invoke sources distinct.
2026-05-19 14:13:24 -07:00
0bf64ca3f2 fix(cli): color hints with magenta prefix and cyan values 2026-05-19 13:53:44 -07:00
e827aca154 fix(cli): hide button_style labels from HITL actions display 2026-05-19 13:45:38 -07:00
8de813c867 feat(cli+api): --file flag support for local upload and remote URL inputs 2026-05-19 13:42:25 -07:00
d09d360530 feat(cli): HITL compact block, resume status framing, spinner fix
- renderHitlBlock: compact inline Actions/Inputs, no blank lines,
  rename Prompt→Message, indent multi-line form_content by 4 spaces
- resumeApp: write ✓ form submitted / workflow execution resumed / ✓ workflow finished
  to stderr in text mode only; JSON/non-text formats are unaffected
- StreamingStructuredStrategy: call spinner.stop() before exit(0) in
  HITL catch to prevent spinner bleed onto pause block output
2026-05-19 00:48:43 -07:00
8a4c87234f chore: switched to shared contract (#36380) 2026-05-19 15:19:41 +08:00
31ea69be66 feat(cli): add release pipeline (#36365) 2026-05-19 14:51:28 +08:00
6de46024a3 fix(cli): CJK table alignment, workflow plain-text output, HITL label rename
- formatTable: use displayWidth() for CJK/fullwidth chars instead of
  String.length, so double-wide characters don't misalign columns
- workflowTextHandler: print value directly when outputs has exactly one
  string key; fall back to JSON for multi-key or non-string outputs
- hitl-render: rename 'Prompt:' label to 'Message:' — less ambiguous
  than 'prompt' in an LLM context
2026-05-18 23:03:07 -07:00
0d5173f73f fix(cli): client-side --mode enum validation, EPIPE guard, stderr redirects
- Reject invalid --mode values before HTTP (e.g. "chatbot" → clear error)
- Add options field to FlagDefinition + validateFlagOptions in parseArgv
- EPIPE exit-0 guard in run.ts catch block
- tree:gen/tree:check diagnostic output redirected to stderr
- Tailwind class sort in web/app/device/components/code-input.tsx
2026-05-18 21:54:49 -07:00
fd1ebdd6cb chore(web): device-flow lint fixes 2026-05-18 21:54:49 -07:00
9fe7adaf69 test(web): terminal state tests — tighter reset assertions + lookup_failed reset coverage 2026-05-18 21:54:49 -07:00
7a6c84dca3 test(web): device-flow terminal state ghost reset buttons 2026-05-18 21:54:49 -07:00
75a8120152 test(web): authorize-sso — add approveExternal args, loadErr, onError, mock reset coverage 2026-05-18 21:54:49 -07:00
ca103b60cc feat(web): device-flow authorize-sso — Avatar card + Button 2026-05-18 21:54:48 -07:00
2c90cfa00f test(web): authorize-account — add onDenied, error-path, mock reset coverage 2026-05-18 21:54:48 -07:00
6851624dbe feat(web): device-flow authorize-account — Avatar card + accountName/avatarUrl props 2026-05-18 21:54:48 -07:00
44d1b66c93 test(web): chooser — add encodeURIComponent + SSO href coverage 2026-05-18 21:54:48 -07:00
f372eb8e5b test(web): assert setPostLoginRedirect called in chooser account button test 2026-05-18 21:54:48 -07:00
36101c7126 feat(web): device-flow chooser — Button + icons 2026-05-18 21:54:48 -07:00
fe212003b1 fix(web): clear errMsg when resetting to code_entry from terminal states 2026-05-18 21:54:47 -07:00
948214fe6a feat(web): device-flow page — signin shell, Button, terminal icons + ghost actions 2026-05-18 21:54:47 -07:00
14328634b5 feat(web): device-flow layout shell + header (signin parity) 2026-05-18 21:54:47 -07:00
de0a44be06 Merge branch 'main' into feat/cli 2026-05-19 10:37:42 +08:00
6153a6b663 [autofix.ci] apply automated fixes 2026-05-19 02:29:28 +00:00
d5dee5326e feat(cli): align HITL pause envelope, split resume into top-level command, JSON purity
- Match server SSE envelope: HitlPausePayload now mirrors {event, task_id,
  workflow_run_id, data:{...}}; renamed user_actions → actions.
- New shared hitl-render.ts: text-mode block, JSON-mode pretty pure JSON
  (no ANSI), colored stderr hint. Exit 0 on pause (no longer a process error).
- Move RunAppResume → top-level ResumeApp at commands/resume/app/; rewire
  imports + regen tree.generated.ts. User-facing strings updated to
  `difyctl resume app`.
- Spinner gated on text mode: structured strategy passes
  enabled = isText && !livePrint so -o json/yaml never spin.
- Hint emits external-channel note when form_token is null (email-only
  delivery to OPENAPI/SERVICE_API surface returns no resume token).
- ColorScheme extended with dim/cyan/green/yellow/magenta methods.
- dify-mock fixture + tests updated to envelope shape; HITL pause test
  split into text + json variants.
2026-05-18 19:25:13 -07:00
49b33647e7 fix(api): map OPENAPI invoke_from to STANDALONE_WEB_APP token surface
OPENAPI workflow runs previously resolved form_token via console-priority
fallback, returning a CONSOLE token that the OPENAPI resume endpoint rejects
with 404. Add a dispatch map so SERVICE_API and OPENAPI invocations both
filter recipients down to their allowed surface; everything else keeps the
priority fallback.
2026-05-18 19:24:53 -07:00
badfd7689a feat(cli): build-time command-tree codegen + hidden/deprecated framework flags
Replaces the hand-written src/commands/tree.ts with a build-time-generated
artifact derived from src/commands/**/index.ts. tree.ts becomes a one-line
re-export of tree.generated.ts. Determinism: lexicographic sort, LF pinned
via .gitattributes, atomic write (tmp + rename), CI-gated by `pnpm tree:check`.

Codegen script (cli/scripts/generate-command-tree.ts) walks the commands
tree, derives canonical PascalCase identifiers (with reserved-word + hyphen
handling), and emits a static ESM module with sorted default imports and a
nested literal of shape CommandTree. Shared exclusion predicate
(isExcludedCommandPath) consumed by both codegen and coverage.test.ts so
underscore-prefixed segments stay non-commands.

Wired pre* lifecycle hooks (prebuild/predev/pretest) and ci composite
gating `tree:check` first. Pack now emits .js outputs (fixedExtension:false)
to drop .mjs; bin/run.js stays on .js. Vitest test.include extended to
cover scripts/.

Framework additions bundled in:
- static hidden = true       omits command from printTopLevelHelp listing
                              (still resolves and runs when invoked)
- static deprecated = '...'  prints "deprecated: <msg>" to stderr before
                              constructing the command

Verified: pnpm ci green (tree:check ok, tsc clean, lint clean, 702 tests
pass, build complete). Smoke: node bin/run.js version + auth login --help,
add-a-command flow, loose-file error case all behave as expected.
2026-05-18 18:25:57 -07:00
0ff00e742f fix(cli): restore BaseError catch routing post-oclif removal
PR #36328 (remove oclif) dropped the DifyCommand.catch() override that
routed BaseError through formatErrorForCli with semantic exit codes.
The replacement catch in framework/run.ts wrote raw err.message and
always exited 1, losing the code prefix, hint, http_status line, JSON
envelope path, and Auth/Usage/VersionCompat exit codes.

framework/run.ts:
- Add sniffOutputFormat(argv) helper: detects --output / -o (= and space
  forms), stops at --, first-occurrence-wins. Schema-free so it survives
  command-construction failures and pre-parse throws.
- Rewrite catch block: branch BaseError -> Error -> non-Error. BaseError
  branch routes through formatErrorForCli({ format: sniffOutputFormat(argv) })
  and exits via err.exit(). Explicit return after each process.exit
  defends against stubbed exits in tests.

run/app/sse-collector.ts:
- decodeStreamError now unwraps openapi-v1 InvokeError envelopes
  ({error_type, args, message}) buried inside env.message. Prefers
  args.description, falls back to inner.message, then raw on shape
  mismatch.

framework/command.ts:
- Sort named imports (fix pre-existing lint error).

Tests (run.test.ts new, sse-collector.test.ts extended):
- 10 sniffOutputFormat cases.
- 12 run() catch-routing cases: BaseError human/JSON, Usage/Server5xx
  exit codes, withRequest method+url in human and JSON, generic Error,
  non-Error throw, success path, constructor-time BaseError, --
  separator.
- 5 decodeStreamError unwrap cases.

Full suite: 675/675. type-check + lint clean. No subclass changes.
2026-05-18 16:34:32 -07:00
a89b43bccc simplify type signature 2026-05-18 20:30:57 +08:00
c6792ce415 [autofix.ci] apply automated fixes 2026-05-18 11:13:20 +00:00
8918142ce1 refactor: remove oclif (#36328) 2026-05-18 19:07:23 +08:00
e2d6ae818c Merge remote-tracking branch 'upstream/main' into feat/cli 2026-05-18 14:00:59 +08:00
2fd7b82970 feat(cli,api): startSpinner export, local install scripts, OPENAPI enum mapping
- spinner.ts: extract startSpinner() returning ActiveSpinner handle for non-blocking use
- scripts/: add install-local.sh / uninstall-local.sh + pnpm install:local / uninstall:local
- api enums: add OPENAPI to InvokeFrom source mapping in both enum definitions
2026-05-17 20:10:01 -07:00
1cc7953f79 feat(cli): add --think flag to strip or show model thinking blocks
By default <think>...</think> blocks emitted by thinking-capable models
are silently stripped from the answer before printing. With --think,
thinking blocks are routed to stderr (with tags preserved) and the
clean answer goes to stdout.

- src/io/think-filter.ts: ThinkChunkFilter (stateful streaming filter
  that handles partial tags split across chunk boundaries), stripThinkBlocks
  and extractThinkBlocks for bulk non-stream processing; \r\n-aware
- stream-handlers.ts: ChatStreamPrinter and CompletionStreamPrinter use
  ThinkChunkFilter; streamPrinterFor(mode, think=false) backward-compat
- streaming-structured.ts: post-collect strip/extract on resp.answer
- streaming-text.ts + print-flags.ts: ctx.think threaded through
- run.ts, resume/run.ts, _strategies/index.ts: think field in types
- run/app/index.ts, resume/index.ts: --think oclif flag, default false
2026-05-17 19:18:34 -07:00
31cf656b35 feat(cli): app run overhaul — always stream, --inputs JSON, HITL pause/resume, Ctrl+C stop
- Remove blocking mode; all apps stream SSE, --stream controls live vs collect output
- Replace --input k=v with --inputs '{json}' (single object, mutually exclusive with --inputs-file)
- Add --workflow-id, --file flags
- HITL: human_input_required → pause JSON to stdout + hint to stderr + exit 2
- Ctrl+C: captures task_id from SSE, calls stop-task, exits 1
- New difyctl run app resume subcommand: POST form, reconnect SSE, stream to completion
- resume: --action (auto-select), --with-history (include_state_snapshot), --stream flags
- Delete BlockingStrategy; simplify pickStrategy(isText, livePrint)
- Add HitlPauseError, SILENT_EVENTS handling in sse-collector and stream-handlers
- Update dify-mock: always SSE, hitl-pause/hitl-resume scenarios, stop/form/events handlers
- Update agent guide: --inputs JSON syntax, HITL pause/resume instructions
2026-05-15 02:53:19 -07:00
8be6665d22 feat(api,cli): openapi HITL endpoints — always-stream, human_input_form, workflow_events, stop-task
- Remove response_mode from AppRunRequest; openapi /run always streams
- Add POST /apps/<id>/tasks/<task_id>/stop (SIGINT hook target)
- Add GET/POST /apps/<id>/form/human_input/<token> (HITL form fetch/submit)
- Add GET /apps/<id>/tasks/<task_id>/events (SSE reconnect after resume)
- Add HumanInputSurface.OPENAPI; map to STANDALONE_WEB_APP recipient type
- Regenerate cli/src/types/data-contracts.ts via pnpm sync-models
2026-05-15 02:50:54 -07:00
c2b91d849d ci(cli): replace Makefile with pnpm scripts, add Dockerfile.dev and CI workflows
- Replace Makefile with package.json scripts (ci, clean, docker:build-dev,
  sync-models, version:info); update cli-tests.yml to use pnpm ci
- Add cli/Dockerfile.dev: multi-stage local build from monorepo source,
  proper cache layer order, non-root user, corepack-driven pnpm version
- Add .dockerignore at repo root to exclude node_modules/dist/.git from
  build context
- Add cli-docker-build.yml: PR + merge_group validation of Dockerfile.dev,
  amd64+arm64 via Depot for org; amd64-only for forks
- cli-release.yml: pin action SHAs, depot-ubuntu-24.04 runner, move
  verify-release step before build for fast-fail
- main-ci.yml: add packages/tsconfig/** and cli-docker-build.yml to cli
  path filter
- Drop cli/docs/specs/** and cli/docs/*.md (superseded)
2026-05-14 22:10:42 -07:00
e0f4e98a2f chore(cli): pre-merge cleanup — docker images, comments, tsconfig lib
- docker-compose.yaml: revert api/web from build: back to image tags
  (1.14.1); fix api_websocket/worker/worker_beat downgraded to 1.14.0
- Remove verbose internal design comments from openapi controllers
- web/next.config.ts: trim anti-framing comment to one line
- cli/tsconfig.json: drop lib:ES2015 override (broke Error.cause typing)
- eslint.config.mjs: ignore cli/context/** and cli/docs/** (local caches)
- pnpm-lock.yaml: regenerate after fresh install
2026-05-14 20:44:51 -07:00
9d554495cf feat(cli): add OPENAPI_ENABLED env switch, default false
Gates the entire /openapi/v1/* blueprint on a single env var.
When false (default), the blueprint is never registered so all
CLI requests return 404 before any auth or logic runs.

Set OPENAPI_ENABLED=true to activate the endpoint group.
2026-05-14 20:16:14 -07:00
c2868075fa Merge remote-tracking branch 'origin/main' into feat/cli
# Conflicts:
#	docker/docker-compose.yaml
#	pnpm-lock.yaml
#	pnpm-workspace.yaml
2026-05-14 20:11:59 -07:00
1a83dfaf1f refactor: use BaseModel in openapi group. Generate ts code from swagger (#36076)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-13 12:56:42 +08:00
83d14e0540 ci(cli): use depot-ubuntu-24.04 for cli-tests jobs
Aligns cli-tests.yml and the cli-tests-skip/finalizer jobs in main-ci.yml
with the rest of the pipeline; all other jobs already use depot-ubuntu-24.04.
2026-05-11 19:46:11 -07:00
1f7da9c191 Merge branch 'main' into feat/cli
Conflicts resolved:
- api/services/app_service.py: extend AppListParams with status + openapi_visible fields so the openapi caller's per-page visibility gate survives the dict->BaseModel refactor; openapi controller now constructs AppListParams.
- pnpm-workspace.yaml: union of CLI-only entries (@napi-rs/keyring, @oclif/*) with main's bumped versions (@next/*, @orpc/*, eslint-plugin-sonarjs, eslint-plugin-storybook); kept eventsource-parser.
- pnpm-lock.yaml: regenerated.
- web/app/signin/utils/post-login-redirect.ts: union impl — keep main's resolvePostLoginRedirect(searchParams) + setOAuthPendingRedirect; add hardened sessionStorage-based setPostLoginRedirect for device flow with same-origin + path whitelist; device redirect takes precedence over oauth pending.
2026-05-11 19:29:37 -07:00
b21d0ae32d fix(cli): call this.parse in arg-less commands to silence oclif UnparsedCommand warning
oclif v4 warns "did not parse its arguments" for any command class whose
run() never calls this.parse(ClassName), independent of whether the
command declares args/flags. Add the call in the five arg-less commands.
2026-05-11 19:01:14 -07:00
6779366dca feat(api,web,cli): difyctl v1.0 — OAuth device flow, /openapi/v1 auth pipeline, CLI client 2026-05-11 18:40:39 -07:00
418 changed files with 33562 additions and 1224 deletions

15
.dockerignore Normal file
View File

@ -0,0 +1,15 @@
**/node_modules
**/.pnpm-store
**/dist
**/.next
**/.turbo
**/.cache
**/__pycache__
**/*.pyc
**/.mypy_cache
**/.ruff_cache
.git
.github
*.md
!web/README.md
!api/README.md

4
.gitattributes vendored
View File

@ -5,3 +5,7 @@
# them. # them.
*.sh text eol=lf *.sh text eol=lf
# Codegen output must stay byte-identical across platforms so
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
*.generated.ts text eol=lf

4
.github/CODEOWNERS vendored
View File

@ -18,6 +18,10 @@
# Docs # Docs
/docs/ @crazywoola /docs/ @crazywoola
# CLI
/cli/ @langgenius/maintainers
/.github/workflows/cli-tests.yml @langgenius/maintainers
# Backend (default owner, more specific rules below will override) # Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost /api/ @QuantumGhost

88
.github/workflows/cli-release.yml vendored Normal file
View File

@ -0,0 +1,88 @@
name: CLI Release
on:
workflow_dispatch:
push:
tags:
- 'difyctl-v*'
concurrency:
group: cli-release-${{ github.ref }}
cancel-in-progress: true
jobs:
release:
name: build standalone binaries (all targets)
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
permissions:
contents: write
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
fetch-depth: 0
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Setup Bun
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
with:
bun-version: latest
- name: Read cli/package.json
id: manifest
run: |
version=$(node -p "require('./package.json').version")
channel=$(node -p "require('./package.json').difyctl.channel")
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
{
echo "version=$version"
echo "channel=$channel"
echo "minDify=$minDify"
echo "maxDify=$maxDify"
} >> "$GITHUB_OUTPUT"
- name: Validate manifest
run: scripts/release-validate-manifest.sh
- name: Install cross-arch native prebuilds
# Re-installs node_modules with every @napi-rs/keyring platform variant
# so `bun build --compile` can embed the right .node into each target.
working-directory: ./
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
- name: Compile standalone binaries (all targets)
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
run: |
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
pnpm build:bin
- name: Generate sha256 checksum file
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
run: scripts/release-write-checksums.sh
- name: Publish GitHub Release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
name: difyctl ${{ steps.manifest.outputs.version }}
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
generate_release_notes: true
fail_on_unmatched_files: true
files: |
cli/dist/bin/difyctl-v*

60
.github/workflows/cli-smoke.yml vendored Normal file
View File

@ -0,0 +1,60 @@
name: CLI Smoke (live dify)
on:
workflow_dispatch:
inputs:
dify_version:
description: "Dify image tag to test against (e.g. 1.7.0)"
type: string
required: true
cli_ref:
description: "Git ref to build the cli from (default: current branch)"
type: string
required: false
permissions:
contents: read
jobs:
smoke:
runs-on: ubuntu-latest
timeout-minutes: 30
defaults:
run:
shell: bash
steps:
- name: Checkout cli ref
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Bring up dify
env:
DIFY_VERSION: ${{ inputs.dify_version }}
run: |
cd docker
cp .env.example .env
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
docker compose up -d api worker web db redis
for i in $(seq 1 60); do
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
echo "dify api ready after ${i}s"
break
fi
sleep 1
done
- name: Run smoke against live dify
working-directory: ./cli
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
- name: Dump dify logs on failure
if: failure()
run: |
cd docker
docker compose logs api worker web --tail=200

46
.github/workflows/cli-tests.yml vendored Normal file
View File

@ -0,0 +1,46 @@
name: CLI Tests
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: cli-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: CLI Tests
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: CI pipeline (typecheck, lint, coverage, build)
run: pnpm ci
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: cli/coverage
flags: cli
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -42,6 +42,7 @@ jobs:
runs-on: depot-ubuntu-24.04 runs-on: depot-ubuntu-24.04
outputs: outputs:
api-changed: ${{ steps.changes.outputs.api }} api-changed: ${{ steps.changes.outputs.api }}
cli-changed: ${{ steps.changes.outputs.cli }}
e2e-changed: ${{ steps.changes.outputs.e2e }} e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }} web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }} vdb-changed: ${{ steps.changes.outputs.vdb }}
@ -62,6 +63,18 @@ jobs:
- 'docker/generate_docker_compose' - 'docker/generate_docker_compose'
- 'docker/ssrf_proxy/**' - 'docker/ssrf_proxy/**'
- 'docker/volumes/sandbox/conf/**' - 'docker/volumes/sandbox/conf/**'
cli:
- 'cli/**'
- 'packages/tsconfig/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- 'eslint.config.mjs'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/cli-tests.yml'
- '.github/workflows/cli-docker-build.yml'
- '.github/actions/setup-web/**'
web: web:
- 'web/**' - 'web/**'
- 'packages/**' - 'packages/**'
@ -184,6 +197,66 @@ jobs:
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1 exit 1
cli-tests-run:
name: Run CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
uses: ./.github/workflows/cli-tests.yml
secrets: inherit
cli-tests-skip:
name: Skip CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped CLI tests
run: echo "No CLI-related changes detected; skipping CLI tests."
cli-tests:
name: CLI Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- cli-tests-run
- cli-tests-skip
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize CLI Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
RUN_RESULT: ${{ needs.cli-tests-run.result }}
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "CLI tests ran successfully."
exit 0
fi
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "CLI tests were skipped because no CLI-related files changed."
exit 0
fi
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run: web-tests-run:
name: Run Web Tests name: Run Web Tests
needs: needs:

7
.gitignore vendored
View File

@ -115,6 +115,12 @@ venv/
ENV/ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
!/cli/src/env/
!/cli/src/commands/env/
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
!/cli/scripts/lib/
.conda/ .conda/
# Spyder project settings # Spyder project settings
@ -247,6 +253,7 @@ scripts/stress-test/reports/
# settings # settings
*.local.json *.local.json
*.local.md *.local.md
*.local.toml
# Code Agent Folder # Code Agent Folder
.qoder/* .qoder/*

View File

@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640 PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600 PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration # Marketplace configuration

View File

@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
ext_logstore, ext_logstore,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_oauth_bearer,
ext_orjson, ext_orjson,
ext_otel, ext_otel,
ext_proxy_fix, ext_proxy_fix,
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry, ext_enterprise_telemetry,
ext_request_logging, ext_request_logging,
ext_session_factory, ext_session_factory,
ext_oauth_bearer,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

View File

@ -11,7 +11,6 @@ from configs import dify_config
from core.helper import encrypter from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.tools.utils.system_encryption import encrypt_system_params from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db from extensions.ext_database import db
from models import Tenant from models import Tenant
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient from models.tools import ToolOAuthSystemClient
from services.plugin.data_migration import PluginDataMigration from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -1,3 +1,5 @@
from typing import Literal
from pydantic import Field from pydantic import Field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -23,7 +25,7 @@ class DeploymentConfig(BaseSettings):
default=False, default=False,
) )
EDITION: str = Field( EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED", default="SELF_HOSTED",
) )

View File

@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
default=60 * 60, default=60 * 60,
) )
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching tenant plugin model providers in Redis",
default=60 * 60 * 24,
)
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field( PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
description="Maximum allowed size (bytes) for plugin-generated files", description="Maximum allowed size (bytes) for plugin-generated files",
default=50 * 1024 * 1024, default=50 * 1024 * 1024,
@ -525,6 +520,44 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
OPENAPI_ENABLED: bool = Field(
description=(
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
"programmatic clients. Set to true to activate; disabled by default."
),
validation_alias=AliasChoices("OPENAPI_ENABLED"),
default=False,
)
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
description=(
"Comma-separated allowlist for /openapi/v1/* CORS. "
"Default empty = same-origin only. Browser-cookie routes within "
"the group reject cross-origin OPTIONS regardless of this list."
),
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
default="",
)
@computed_field
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
description=(
"Comma-separated client_id values accepted at "
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
"without code changes. Unknown client_id returns 400 unsupported_client."
),
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
default="difyctl",
)
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field( HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10 ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
) )
@ -900,6 +933,17 @@ class AuthConfig(BaseSettings):
default=86400, default=86400,
) )
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
class ModerationConfig(BaseSettings): class ModerationConfig(BaseSettings):
""" """
@ -1186,6 +1230,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task", description="Enable scheduled workflow run cleanup task",
default=False, default=False,
) )
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field( ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task", description="Enable mail clean document notify task",
default=False, default=False,

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, override from typing import Any
from pydantic import Field from pydantic import Field
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
self.namespace = configs["APOLLO_NAMESPACE"] self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace) self.remote_configs = self.client.get_all_dicts(self.namespace)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict): if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}") raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

View File

@ -1,7 +1,7 @@
import logging import logging
import os import os
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, override from typing import Any
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}") raise RuntimeError(f"Failed to parse config: {e}")
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
field_value = self.remote_configs.get(field_name) field_value = self.remote_configs.get(field_name)
if field_value is None: if field_value is None:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, override, runtime_checkable from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel from pydantic import BaseModel
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
self._config = config or {} self._config = config or {}
self._extensions: dict[str, Any] = {} self._extensions: dict[str, Any] = {}
@override
def get_config(self, key: str, default: Any = None) -> Any: def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key.""" """Get configuration value by key."""
return self._config.get(key, default) return self._config.get(key, default)
@override
def get_extension(self, name: str) -> Any: def get_extension(self, name: str) -> Any:
"""Get extension by name.""" """Get extension by name."""
return self._extensions.get(name) return self._extensions.get(name)
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
self._extensions[name] = extension self._extensions[name] = extension
@contextmanager @contextmanager
@override
def enter(self) -> Generator[None, None, None]: def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op).""" """Enter null context (no-op)."""
yield yield

View File

@ -6,7 +6,7 @@ import contextvars
import threading import threading
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, final, override from typing import Any, final
from flask import Flask, current_app, g from flask import Flask, current_app, g
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
""" """
self._flask_app = flask_app self._flask_app = flask_app
@override
def get_config(self, key: str, default: Any = None) -> Any: def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config.""" """Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default) return self._flask_app.config.get(key, default)
@override
def get_extension(self, name: str) -> Any: def get_extension(self, name: str) -> Any:
"""Get Flask extension by name.""" """Get Flask extension by name."""
return self._flask_app.extensions.get(name) return self._flask_app.extensions.get(name)
@contextmanager @contextmanager
@override
def enter(self) -> Generator[None, None, None]: def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context.""" """Enter Flask app context."""
with self._flask_app.app_context(): with self._flask_app.app_context():

View File

@ -15,7 +15,6 @@ from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.plugin_service import PluginService
from fields.base import ResponseModel from fields.base import ResponseModel
from graphon.model_runtime.utils.encoders import jsonable_encoder from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class ParserList(BaseModel): class ParserList(BaseModel):

View File

@ -0,0 +1,128 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
attach_anti_framing(bp)
api = ExternalApi(
bp,
version="1.0",
title="OpenAPI",
description="User-scoped programmatic API (bearer auth)",
)
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
# Register response/query models BEFORE importing controller modules so that
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.openapi._models import (
AccountPayload,
AccountResponse,
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppInfoResponse,
AppListQuery,
AppListResponse,
AppListRow,
AppRunRequest,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
MessageMetadata,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
RevokeResponse,
ServerVersionResponse,
SessionListResponse,
SessionRow,
TagItem,
UsageInfo,
WorkflowRunData,
WorkspaceDetailResponse,
WorkspaceListResponse,
WorkspacePayload,
WorkspaceSummaryResponse,
)
from fields.file_fields import FileResponse
register_schema_models(
openapi_ns,
AppDescribeQuery,
AppListQuery,
AppRunRequest,
DeviceCodeRequest,
DevicePollRequest,
DeviceLookupQuery,
DeviceMutateRequest,
PermittedExternalAppsListQuery,
)
register_response_schema_models(
openapi_ns,
TagItem,
UsageInfo,
MessageMetadata,
AppListRow,
AppListResponse,
AppInfoResponse,
AppDescribeInfo,
AppDescribeResponse,
WorkflowRunData,
AccountPayload,
WorkspacePayload,
AccountResponse,
SessionRow,
SessionListResponse,
PermittedExternalAppsListResponse,
RevokeResponse,
WorkspaceSummaryResponse,
WorkspaceListResponse,
WorkspaceDetailResponse,
DeviceCodeResponse,
DeviceLookupResponse,
DeviceMutateResponse,
FileResponse,
ServerVersionResponse,
)
from . import (
_meta,
account,
app_run,
apps,
apps_permitted_external,
files,
human_input_form,
index,
oauth_device,
oauth_device_sso,
workflow_events,
workspaces,
)
# Request models are imported from _models.py and registered above.
__all__ = [
"_meta",
"account",
"app_run",
"apps",
"apps_permitted_external",
"files",
"human_input_form",
"index",
"oauth_device",
"oauth_device_sso",
"workflow_events",
"workspaces",
]
api.add_namespace(openapi_ns)

View File

@ -0,0 +1,66 @@
"""Audit emission for openapi app-run endpoints.
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
def emit_app_run(
*,
app_id: str,
tenant_id: str,
caller_kind: str,
mode: str,
surface: str,
) -> None:
logger.info(
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
EVENT_APP_RUN_OPENAPI,
app_id,
tenant_id,
caller_kind,
mode,
surface,
extra={
"audit": True,
"event": EVENT_APP_RUN_OPENAPI,
"app_id": app_id,
"tenant_id": tenant_id,
"caller_kind": caller_kind,
"mode": mode,
"surface": surface,
},
)
def emit_wrong_surface(
*,
subject_type: str | None,
attempted_path: str,
client_id: str | None,
token_id: str | None,
) -> None:
logger.warning(
"audit: %s subject_type=%s attempted_path=%s",
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
subject_type,
attempted_path,
extra={
"audit": True,
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
"subject_type": subject_type,
"attempted_path": attempted_path,
"client_id": client_id,
"token_id": token_id,
},
)

View File

@ -0,0 +1,143 @@
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
from __future__ import annotations
from typing import Any, cast
from controllers.service_api.app.error import AppUnavailableError
from models import App
from models.model import AppMode
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": {},
"required": [],
}
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
def _file_object_shape() -> dict[str, Any]:
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
return {
"type": "object",
"properties": {
"type": {"type": "string"},
"transfer_method": {"type": "string"},
"url": {"type": "string"},
"upload_file_id": {"type": "string"},
},
"additionalProperties": True,
}
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
label = row.get("label") or row.get("variable", "")
base: dict[str, Any] = {"title": label} if label else {}
if row_type in ("text-input", "paragraph"):
out: dict[str, Any] = {"type": "string"} | base
max_length = row.get("max_length")
if isinstance(max_length, int) and max_length > 0:
out["maxLength"] = max_length
return out
if row_type == "select":
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
if row_type == "number":
return {"type": "number"} | base
if row_type == "file":
return _file_object_shape() | base
if row_type == "file-list":
return {
"type": "array",
"items": _file_object_shape(),
} | base
return None
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
"""Translate a user_input_form row list into (properties, required-list).
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
Unknown variable types are skipped (forward-compat).
"""
properties: dict[str, Any] = {}
required: list[str] = []
for row in form:
if not isinstance(row, dict) or len(row) != 1:
continue
((row_type, row_body),) = row.items()
if not isinstance(row_body, dict):
continue
variable = row_body.get("variable")
if not variable:
continue
schema = _row_to_schema(row_type, row_body)
if schema is None:
continue
properties[variable] = schema
if row_body.get("required"):
required.append(variable)
return properties, required
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
Raises `AppUnavailableError` on misconfigured apps.
"""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
return (
workflow.features_dict,
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
)
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
def build_input_schema(app: App) -> dict[str, Any]:
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
completion / workflow: `inputs` object only.
Raises `AppUnavailableError` on misconfigured apps.
"""
_, user_input_form = resolve_app_config(app)
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
properties: dict[str, Any] = {}
required: list[str] = []
if app.mode in _CHAT_FAMILY:
properties["query"] = {"type": "string", "minLength": 1}
required.append("query")
properties["inputs"] = {
"type": "object",
"properties": inputs_props,
"required": inputs_required,
"additionalProperties": False,
}
required.append("inputs")
return {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": properties,
"required": required,
}

View File

@ -0,0 +1,23 @@
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
Returns the server's project version and edition so the difyctl CLI can probe
compatibility without needing to be logged in. Mirrors the `_health` endpoint
in `index.py`.
"""
from flask_restx import Resource
from configs import dify_config
from controllers.openapi import openapi_ns
from controllers.openapi._models import ServerVersionResponse
@openapi_ns.route("/_version")
class VersionApi(Resource):
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
def get(self):
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
return ServerVersionResponse(
version=dify_config.project.version,
edition=edition,
).model_dump(mode="json")

View File

@ -0,0 +1,326 @@
"""Shared response substructures for openapi endpoints."""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from libs.helper import UUIDStrOrEmpty, uuid_value
from models.model import AppMode
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
MAX_PAGE_LIMIT = 200
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class MessageMetadata(BaseModel):
usage: UsageInfo | None = None
retriever_resources: list[dict[str, Any]] = []
class PaginationEnvelope[T](BaseModel):
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
page: int
limit: int
total: int
has_more: bool
data: list[T]
@classmethod
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
class TagItem(BaseModel):
name: str
class AppListRow(BaseModel):
id: str
name: str
description: str | None = None
mode: AppMode
tags: list[TagItem] = []
updated_at: str | None = None
created_by_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
class AppListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class PermittedExternalAppsListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class AppInfoResponse(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author: str | None = None
tags: list[TagItem] = []
class AppDescribeInfo(AppInfoResponse):
updated_at: str | None = None
service_api_enabled: bool
is_agent: bool = False
class AppDescribeResponse(BaseModel):
info: AppDescribeInfo | None = None
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
class ChatMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
conversation_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class CompletionMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class WorkflowRunData(BaseModel):
id: str
workflow_id: str
status: str
outputs: dict[str, Any] = Field(default_factory=dict)
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
class WorkflowRunResponse(BaseModel):
workflow_run_id: str
task_id: str
mode: Literal["workflow"] = "workflow"
data: WorkflowRunData
class AccountPayload(BaseModel):
id: str
email: str
name: str
class WorkspacePayload(BaseModel):
id: str
name: str
role: str
class AccountResponse(BaseModel):
subject_type: str
subject_email: str | None = None
subject_issuer: str | None = None
account: AccountPayload | None = None
workspaces: list[WorkspacePayload] = []
default_workspace_id: str | None = None
class SessionRow(BaseModel):
id: str
prefix: str
client_id: str
device_label: str
created_at: str | None = None
last_used_at: str | None = None
expires_at: str | None = None
class SessionListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[SessionRow]
class RevokeResponse(BaseModel):
status: str
class WorkspaceSummaryResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
class WorkspaceListResponse(BaseModel):
workspaces: list[WorkspaceSummaryResponse]
class WorkspaceDetailResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
created_at: str | None = None
class DeviceCodeResponse(BaseModel):
device_code: str
user_code: str
verification_uri: str
expires_in: int
interval: int
class DeviceLookupResponse(BaseModel):
valid: bool
expires_in_remaining: int = 0
client_id: str | None = None
class DeviceMutateResponse(BaseModel):
status: str
class ServerVersionResponse(BaseModel):
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
version: str
edition: Literal["SELF_HOSTED", "CLOUD"]
class AppDescribeQuery(BaseModel):
"""`?fields=` allow-list for GET /apps/<id>/describe.
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
"""
model_config = ConfigDict(extra="forbid")
fields: set[str] | None = None
workspace_id: str | None = None
@field_validator("workspace_id", mode="before")
@classmethod
def _validate_workspace_id(cls, v: object) -> str | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("workspace_id must be a string")
try:
import uuid as _uuid
_uuid.UUID(v)
except ValueError:
raise ValueError("workspace_id must be a valid UUID")
return v
@field_validator("fields", mode="before")
@classmethod
def _parse_fields(cls, v: object) -> set[str] | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("fields must be a comma-separated string")
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
members = {m.strip() for m in v.split(",") if m.strip()}
unknown = members - _ALLOWED_DESCRIBE_FIELDS
if unknown:
raise ValueError(f"unknown field(s): {sorted(unknown)}")
return members
class AppListQuery(BaseModel):
"""mode is a closed enum."""
workspace_id: str
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
tag: str | None = Field(None, max_length=100)
class AppRunRequest(BaseModel):
inputs: dict[str, Any]
query: str | None = None
files: list[dict[str, Any]] | None = None
conversation_id: UUIDStrOrEmpty | None = None
auto_generate_name: bool = True
workflow_id: str | None = None
workspace_id: UUIDStrOrEmpty | None = None
@field_validator("conversation_id", mode="before")
@classmethod
def _normalize_conv(cls, value: str | None) -> str | None:
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
class DeviceCodeRequest(BaseModel):
client_id: str
device_label: str
class DevicePollRequest(BaseModel):
device_code: str
client_id: str
class DeviceLookupQuery(BaseModel):
user_code: str
class DeviceMutateRequest(BaseModel):
user_code: str
class PermittedExternalAppsListQuery(BaseModel):
"""Strict (extra='forbid')."""
model_config = ConfigDict(extra="forbid")
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)

View File

@ -0,0 +1,169 @@
from __future__ import annotations
from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
MAX_PAGE_LIMIT,
AccountPayload,
AccountResponse,
PaginationEnvelope,
RevokeResponse,
SessionListResponse,
SessionRow,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from services.account_service import AccountService, TenantService
from services.oauth_device_flow import (
list_active_sessions,
revoke_oauth_token,
token_belongs_to_subject,
)
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email,
subject_issuer=ctx.subject_issuer,
account=None,
workspaces=[],
default_workspace_id=None,
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email or (account.email if account else None),
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
).model_dump(mode="json")
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
all_rows = list_active_sessions(db.session, ctx, now)
total = len(all_rows)
sliced = all_rows[(page - 1) * limit : page * limit]
items = [
SessionRow(
id=str(r.id),
prefix=r.prefix,
client_id=r.client_id,
device_label=r.device_label,
created_at=_iso(r.created_at),
last_used_at=_iso(r.last_used_at),
expires_at=_iso(r.expires_at),
)
for r in sliced
]
return (
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
200,
)
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
# 404 (not 403) on cross-subject so the endpoint doesn't leak
# token IDs that belong to other subjects.
if not token_belongs_to_subject(db.session, session_id, ctx):
raise NotFound("session not found")
revoke_oauth_token(db.session, redis_client, session_id)
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt.isoformat().replace("+00:00", "Z")
def _pick_default_workspace(memberships) -> str | None:
if not memberships:
return None
for join, tenant in memberships:
if getattr(join, "current", False):
return str(tenant.id)
return str(memberships[0][1].id)
def _workspace_payload(row) -> WorkspacePayload:
join, tenant = row
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
def _account_payload(account) -> AccountPayload:
return AccountPayload(id=str(account.id), email=account.email, name=account.name)

View File

@ -0,0 +1,165 @@
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
from __future__ import annotations
import logging
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import AppRunRequest
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
IsDraftWorkflowError,
WorkflowIdFormatError,
WorkflowNotFoundError,
)
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@contextmanager
def _translate_service_errors() -> Iterator[None]:
try:
yield
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
return AppGenerateService.generate(
app_model=app,
user=caller,
args=args,
invoke_from=InvokeFrom.OPENAPI,
streaming=streaming,
)
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
if not payload.query or not payload.query.strip():
raise UnprocessableEntity("query_required_for_chat")
args = payload.model_dump(exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False
args.setdefault("query", "")
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
if payload.query is not None:
raise UnprocessableEntity("query_not_supported_for_workflow")
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
AppMode.CHAT: _run_chat,
AppMode.AGENT_CHAT: _run_chat,
AppMode.ADVANCED_CHAT: _run_chat,
AppMode.COMPLETION: _run_completion,
AppMode.WORKFLOW: _run_workflow,
}
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
handler = _DISPATCH.get(app_model.mode)
if handler is None:
raise UnprocessableEntity("mode_not_runnable")
try:
stream_obj = handler(app_model, caller, payload)
except HTTPException:
raise
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
emit_app_run(
app_id=app_model.id,
tenant_id=app_model.tenant_id,
caller_kind=caller_kind,
mode=str(app_model.mode),
surface="apps",
)
return helper.compact_generate_response(stream_obj)
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -0,0 +1,270 @@
"""GET /openapi/v1/apps and per-app reads.
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is last → outermost → publishes the auth ContextVar before `require_scope`
reads it.
"""
from __future__ import annotations
import uuid as _uuid
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.common.fields import Parameters
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
from controllers.openapi._models import (
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppListQuery,
AppListResponse,
AppListRow,
TagItem,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
get_auth_ctx,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
require_scope(Scope.APPS_READ),
accept_subjects(SubjectType.ACCOUNT),
validate_bearer(accept=ACCEPT_USER_ANY),
]
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
_EMPTY_PARAMETERS: dict[str, Any] = {
"opening_statement": None,
"suggested_questions": [],
"user_input_form": [],
"file_upload": None,
"system_parameters": {},
}
class AppReadResource(Resource):
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
method_decorators = _APPS_READ_DECORATORS
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
ctx: AuthContext = get_auth_ctx()
try:
parsed_uuid = _uuid.UUID(app_id)
is_uuid = True
except ValueError:
parsed_uuid = None
is_uuid = False
if is_uuid:
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None:
raise NotFound("app not found")
else:
if not workspace_id:
raise UnprocessableEntity("workspace_id is required for name-based lookup")
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
if len(matches) == 0:
raise NotFound("app not found")
if len(matches) > 1:
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
for m in matches:
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
raise Conflict("".join(lines))
app = matches[0]
require_workspace_member(ctx, str(app.tenant_id))
return app, ctx
def parameters_payload(app: App) -> dict:
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
features_dict, user_input_form = resolve_app_config(app)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
def get(self, app_id: str):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
app, _ = self._load(app_id, workspace_id=query.workspace_id)
requested = query.fields
want_info = requested is None or "info" in requested
want_params = requested is None or "parameters" in requested
want_schema = requested is None or "input_schema" in requested
info = (
AppDescribeInfo(
id=str(app.id),
name=app.name,
mode=app.mode,
description=app.description,
tags=[TagItem(name=t.name) for t in app.tags],
author=app.author_name,
updated_at=app.updated_at.isoformat() if app.updated_at else None,
service_api_enabled=bool(app.enable_api),
is_agent=app.mode in ("agent-chat", "advanced-chat"),
)
if want_info
else None
)
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
if want_params:
try:
parameters = parameters_payload(app)
except AppUnavailableError:
parameters = dict(_EMPTY_PARAMETERS)
if want_schema:
try:
input_schema = build_input_schema(app)
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return (
AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
).model_dump(mode="json", exclude_none=False),
200,
)
@openapi_ns.route("/apps")
class AppListApi(Resource):
method_decorators = _APPS_READ_DECORATORS
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
def get(self):
ctx: AuthContext = get_auth_ctx()
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
workspace_id = query.workspace_id
require_workspace_member(ctx, workspace_id)
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
mode="json"
),
200,
)
if query.name:
try:
parsed_uuid = _uuid.UUID(query.name)
except ValueError:
parsed_uuid = None
else:
parsed_uuid = None
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None or str(app.tenant_id) != workspace_id:
return empty
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
item = AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[TagItem(name=t.name) for t in app.tags],
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=getattr(app, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
return env.model_dump(mode="json"), 200
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
params = AppListParams(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else "all", # type:ignore
name=query.name,
tag_ids=tag_ids,
status="normal",
# Visibility gate pushed into the query — pagination.total stays
# consistent across pages because invisible rows never count.
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
if pagination is None:
return empty
tenant_name = None
if pagination.items:
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
items = [
AppListRow(
id=str(r.id),
name=r.name,
description=r.description,
mode=r.mode,
tags=[TagItem(name=t.name) for t in r.tags],
updated_at=r.updated_at.isoformat() if r.updated_at else None,
created_by_name=getattr(r, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
for r in pagination.items
]
env = AppListResponse(
page=query.page,
limit=query.limit,
total=cast(int, pagination.total),
has_more=query.page * query.limit < cast(int, pagination.total),
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,102 @@
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
(public / sso_verified). License-gated: CE deploys never enable the
EE blueprint chain so this module is unreachable there.
"""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AppListRow,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
method_decorators = [
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
license_required,
accept_subjects(SubjectType.EXTERNAL_SSO),
validate_bearer(accept=ACCEPT_USER_ANY),
enterprise_only,
]
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
def get(self):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
page_result = list_permitted_apps(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else None,
name=query.name,
)
if not page_result.app_ids:
env = PermittedExternalAppsListResponse(
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
)
return env.model_dump(mode="json"), 200
apps_by_id: dict[str, App] = {
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
}
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
items: list[AppListRow] = []
for app_id in page_result.app_ids:
app = apps_by_id.get(app_id)
if not app or app.status != "normal":
continue
tenant = tenants_by_id.get(str(app.tenant_id))
items.append(
AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[], # tenant-scoped; not surfaced cross-tenant
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=None, # cross-tenant author leak prevention
workspace_id=str(app.tenant_id),
workspace_name=tenant.name if tenant else None,
)
)
env = PermittedExternalAppsListResponse(
page=query.page,
limit=query.limit,
total=page_result.total,
has_more=query.page * query.limit < page_result.total,
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,3 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -0,0 +1,46 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
if FeatureService.get_system_features().webapp_auth.enabled:
return AclStrategy()
return MembershipStrategy()
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -0,0 +1,68 @@
"""Mutable per-request context for the openapi auth pipeline.
Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context — handlers
read populated values via the decorator's kwargs unpacking.
Context is intentionally decoupled from Flask's ``Request``: the pipeline
guard extracts whatever transport-level inputs the steps need (bearer
token, path params) at the boundary and writes them into Context fields,
so steps stay testable without a request object and won't leak coupling
to a specific framework.
"""
from __future__ import annotations
import uuid
from collections.abc import Mapping
from contextvars import Token
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from werkzeug.exceptions import Unauthorized
from libs.oauth_bearer import AuthContext, Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
required_scope: Scope
bearer_token: str | None = None
path_params: Mapping[str, str] = field(default_factory=dict)
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None
auth_ctx_reset_token: Token[AuthContext] | None = None
@property
def must_tenant(self) -> Tenant:
if not self.tenant:
raise Unauthorized("tenant is not associated")
return self.tenant
@property
def must_subject_type(self) -> SubjectType:
if not self.subject_type:
raise Unauthorized("subject_type unset — BearerCheck did not run")
return self.subject_type
class Step(Protocol):
"""One responsibility. Mutate ctx or raise to short-circuit."""
def __call__(self, ctx: Context) -> None: ...

View File

@ -0,0 +1,51 @@
"""Pipeline IS the auth scheme.
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from functools import wraps
from flask import request
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
class Pipeline:
def __init__(self, *steps: Step) -> None:
self._steps = steps
def run(self, ctx: Context) -> None:
for step in self._steps:
step(ctx)
def guard(self, *, scope: Scope):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# Extract transport-level inputs at the boundary so steps
# stay decoupled from Flask's request object.
ctx = Context(
required_scope=scope,
bearer_token=extract_bearer(request),
path_params=dict(request.view_args or {}),
)
try:
self.run(ctx)
kwargs.update(
app_model=ctx.app,
caller=ctx.caller,
caller_kind=ctx.caller_kind,
)
return view(*args, **kwargs)
finally:
if ctx.auth_ctx_reset_token is not None:
reset_auth_ctx(ctx.auth_ctx_reset_token)
return decorated
return decorator

View File

@ -0,0 +1,170 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also publishes the
resolved identity to the openapi auth ``ContextVar`` (the same one the
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
surface gate and any handler reading the request-scoped context has a single
source of truth across both auth-attach paths. The reset token is stashed
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
its `finally` so worker-thread reuse can't leak identity across requests.
"""
from __future__ import annotations
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
check_workspace_membership,
get_authenticator,
set_auth_ctx,
)
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here.
Also publishes the resolved `AuthContext` via
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
``validate_bearer`` writes — so the surface gate + downstream readers
don't see two different identity sources. The reset token is parked on
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
def __call__(self, ctx: Context) -> None:
if not ctx.bearer_token:
raise Unauthorized("bearer required")
try:
authn = get_authenticator().authenticate(ctx.bearer_token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
class SurfaceCheck:
"""Reject the request if the resolved subject is not in `accepted`."""
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
self._accepted = accepted
def __call__(self, ctx: Context) -> None:
check_surface(self._accepted)
class AppResolver:
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route — that is the design lock-in (no body /
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
``ctx.path_params`` at the boundary so this step doesn't need to know
about the request object.
"""
def __call__(self, ctx: Context) -> None:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only — SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.must_tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy
def __call__(self, ctx: Context) -> None:
if not self._resolve().authorize(ctx):
raise Forbidden("subject_no_app_access")
class CallerMount:
def __init__(self, *mounters: CallerMounter) -> None:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.must_subject_type):
m.mount(ctx)
return
raise Unauthorized("no caller mounter for subject type")
__all__ = [
"AppAuthzCheck",
"AppResolver",
"AuthContext",
"BearerCheck",
"CallerMount",
"ScopeCheck",
"SurfaceCheck",
"WorkspaceMembershipCheck",
]

View File

@ -0,0 +1,168 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
"""Per-app ACL, evaluated in two stages.
The EE gateway has already enforced tenancy and workspace membership
by the time this strategy runs, so AclStrategy only owns per-app ACL:
1. Subject vs access-mode compatibility (pure rule table). External-SSO
bearers belong to public-facing apps only; account bearers cover the
full set. A mismatch is an immediate deny — no IO.
2. For modes that pair with the subject, decide whether the inner
permission API must run. Only `PRIVATE` (per-app selected-user list)
requires it; the remaining modes are pass-through.
"""
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
SubjectType.ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
SubjectType.EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
def authorize(self, ctx: Context) -> bool:
if ctx.app is None:
return False
access_mode = self._fetch_access_mode(ctx.app.id)
if access_mode is None:
return False
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
return False
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
return True
return self._inner_permission_check(ctx)
@staticmethod
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if settings is None:
return None
try:
return WebAppAccessMode(settings.access_mode)
except ValueError:
return None
@classmethod
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
def _inner_permission_check(self, ctx: Context) -> bool:
if ctx.app is None:
return False
user_id = self._resolve_user_id(ctx)
if user_id is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_id=ctx.app.id,
)
@staticmethod
def _resolve_user_id(ctx: Context) -> str | None:
if ctx.subject_type == SubjectType.ACCOUNT:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
class MembershipStrategy:
"""Tenant-membership fallback.
Used when webapp-auth is disabled (CE deployment). Account-bearing
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
denied (it requires the webapp-auth surface).
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
if ctx.tenant is None:
return False
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # type:ignore
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
class CallerMounter(Protocol):
def applies_to(self, subject_type: SubjectType) -> bool: ...
def mount(self, ctx: Context) -> None: ...
class AccountMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)
ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -0,0 +1,89 @@
"""Surface gate.
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
step) is the pipeline-level form. Both delegate to `check_surface` so the
audit emit + canonical-path message are single-sourced.
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
``openapi.wrong_surface_denied``.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from flask import request
from werkzeug.exceptions import Forbidden
from controllers.openapi._audit import emit_wrong_surface
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
_CANONICAL_PATH: dict[SubjectType, str] = {
SubjectType.ACCOUNT: "/openapi/v1/apps",
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
}
F = TypeVar("F", bound=Callable[..., object])
def check_surface(accepted: frozenset[SubjectType]) -> None:
"""Enforce that the resolved subject is in ``accepted``.
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
set the bearer layer didn't run — that's a wiring bug, not a
user-driven failure, so surface it as a ``RuntimeError`` instead of
a silent 403.
"""
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
)
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
if subject in accepted:
return
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
emit_wrong_surface(
subject_type=subject.value if subject else None,
attempted_path=request.path,
client_id=getattr(ctx, "client_id", None),
token_id=_stringify(getattr(ctx, "token_id", None)),
)
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
accepted_set: frozenset[SubjectType] = frozenset(accepted)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
check_surface(accepted_set)
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco
def _coerce_subject_type(raw: object) -> SubjectType | None:
if raw is None:
return None
if isinstance(raw, SubjectType):
return raw
if isinstance(raw, str):
return SubjectType(raw)
return None
def _stringify(value: object) -> str | None:
if value is None:
return None
return str(value)

View File

@ -0,0 +1,72 @@
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from werkzeug.exceptions import BadRequest
import services
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from extensions.ext_database import db
from fields.file_fields import FileResponse
from libs.oauth_bearer import Scope
from models import Account, App
from services.file_service import FileService
@openapi_ns.route("/apps/<string:app_id>/files/upload")
class AppFileUploadApi(Resource):
@openapi_ns.doc("upload_file_for_app_input")
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
@openapi_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request — no file or filename missing",
401: "Unauthorized — invalid or expired bearer token",
413: "File too large",
415: "Unsupported file type or blocked extension",
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.stream.read(),
mimetype=file.mimetype,
user=caller,
)
except ValueError as exc:
raise BadRequest(str(exc))
except services.errors.file.FileTooLargeError as exc:
raise FileTooLargeError(exc.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError(exc.description)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201

View File

@ -0,0 +1,107 @@
"""
OpenAPI bearer-authed human input form endpoints.
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
"""
from __future__ import annotations
import json
import logging
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
from libs.oauth_bearer import Scope
from models.model import App
from services.human_input_service import FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_openapi(form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
raise NotFound("Form not found")
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
submission_user_id: str | None = None
submission_end_user_id: str | None = None
if caller_kind == "account":
submission_user_id = caller.id
else:
submission_end_user_id = caller.id
if form.recipient_type is None:
logger.warning("Recipient type is None for form, form_token=%s", form_token)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=form.recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=submission_user_id,
submission_end_user_id=submission_end_user_id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -0,0 +1,9 @@
from flask_restx import Resource
from controllers.openapi import openapi_ns
@openapi_ns.route("/_health")
class HealthApi(Resource):
def get(self):
return {"ok": True}

View File

@ -0,0 +1,398 @@
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
sub-groups in one module:
Protocol (RFC 8628, public + rate-limited):
POST /oauth/device/code
POST /oauth/device/token
GET /oauth/device/lookup
Approval (account branch, console-cookie authed):
POST /oauth/device/approve
POST /oauth/device/deny
SSO branch lives in oauth_device_sso.py.
"""
from __future__ import annotations
import logging
from typing import Any
from flask import request
from flask_login import login_required
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
LIMIT_DEVICE_CODE_PER_IP,
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
SlowDownDecision,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# =========================================================================
# Validation helpers
# =========================================================================
def _validate_json[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _validate_query[M: BaseModel](model: type[M]) -> M:
try:
return model.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
# =========================================================================
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
# =========================================================================
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
def post(self):
payload = _validate_json(DeviceCodeRequest)
client_id = payload.client_id
device_label = payload.device_label
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
return {"error": "unsupported_client"}, 400
store = DeviceFlowRedis(redis_client)
ip = extract_remote_ip(request)
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
return {
"device_code": device_code,
"user_code": user_code,
"verification_uri": _verification_uri(),
"expires_in": expires_in,
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
}, 200
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
def post(self):
payload = _validate_json(DevicePollRequest)
device_code = payload.device_code
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
"""Read-only — public for pre-validate before login. user_code is
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
"""
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
@rate_limit(LIMIT_LOOKUP_PUBLIC)
def get(self):
payload = _validate_query(DeviceLookupQuery)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
_device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
return {
"valid": True,
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
"client_id": state.client_id,
}, 200
# =========================================================================
# Approval endpoints — account branch (cookie-authed)
# =========================================================================
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
# SET NX guard — without it, two in-flight approves both pass
# PENDING, both mint, and the second upsert silently rotates the
# first caller into an already-revoked token.
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
return {"error": "approve_in_progress"}, 409
try:
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=tenant)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=account.email,
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
account_id=str(account.id),
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
poll_payload = _build_account_poll_payload(account, tenant, mint)
try:
store.approve(
device_code,
subject_email=account.email,
account_id=str(account.id),
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError):
# Row minted but state vanished — roll forward; the orphan
# token is revocable via auth devices list / Authorized Apps.
logger.exception("device_flow: approve raced on %s", device_code)
return {"error": "state_lost"}, 409
finally:
redis_client.delete(guard_key)
_emit_approve_audit(state, account, tenant, mint)
return {"status": "approved"}, 200
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
try:
store.deny(device_code)
except (StateNotFoundError, InvalidTransitionError):
logger.exception("device_flow: deny raced on %s", device_code)
return {"error": "state_lost"}, 409
_emit_deny_audit(state)
return {"status": "denied"}, 200
# =========================================================================
# Helpers
# =========================================================================
def _verification_uri() -> str:
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
if base:
return f"{base.rstrip('/')}/device"
return f"{request.host_url.rstrip('/')}/device"
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id,
state.created_ip,
poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None
if tenant and any(w.id == str(tenant) for w in workspaces):
default_ws_id = str(tenant)
if default_ws_id is None:
for _t, m in rows:
if getattr(m, "current", False):
default_ws_id = str(m.tenant_id)
break
if default_ws_id is None and workspaces:
default_ws_id = workspaces[0].id
payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.ACCOUNT,
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
"workspaces": [w.model_dump(mode="json") for w in workspaces],
"default_workspace_id": default_ws_id,
"token_id": str(mint.token_id),
}
return payload
def _emit_approve_audit(state, account, tenant, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
mint.token_id,
account.email,
state.client_id,
state.device_label,
mint.expires_at,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"token_id": str(mint.token_id),
"subject_type": SubjectType.ACCOUNT,
"subject_email": account.email,
"account_id": str(account.id),
"tenant_id": tenant,
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["full"],
"expires_at": mint.expires_at.isoformat(),
},
)
def _emit_deny_audit(state) -> None:
logger.warning(
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
state.client_id,
state.device_label,
extra={
"audit": True,
"event": "oauth.device_flow_denied",
"client_id": state.client_id,
"device_label": state.device_label,
},
)

View File

@ -0,0 +1,348 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
EE-only. Browser flow:
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
Function-based (raw @bp.route) rather than Resource classes because the
handlers do redirects + cookie kwargs that don't fit the Resource shape.
"""
from __future__ import annotations
import logging
import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from werkzeug.exceptions import (
BadGateway,
BadRequest,
Conflict,
Forbidden,
NotFound,
Unauthorized,
)
from controllers.openapi import bp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import jws
from libs.device_flow_security import (
APPROVAL_GRANT_COOKIE_NAME,
ApprovalGrantClaims,
approval_grant_cleared_cookie_kwargs,
approval_grant_cookie_kwargs,
consume_approval_grant_nonce,
consume_sso_assertion_nonce,
enterprise_only,
mint_approval_grant,
verify_approval_grant,
)
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
from libs.rate_limit import (
LIMIT_APPROVE_EXT_PER_EMAIL,
LIMIT_SSO_INITIATE_PER_IP,
enforce,
rate_limit,
)
from services.account_service import AccountService
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
user_code = (request.args.get("user_code") or "").strip().upper()
if not user_code:
raise BadRequest("user_code required")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise BadRequest("invalid_user_code")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise BadRequest("invalid_user_code")
keyset = jws.KeySet.from_shared_secret()
signed_state = jws.sign(
keyset,
payload={
"redirect_url": "",
"app_code": "",
"intent": "device_flow",
"user_code": user_code,
"nonce": secrets.token_urlsafe(16),
"return_to": "",
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
},
aud=jws.AUD_STATE_ENVELOPE,
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
)
try:
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
except Exception as e:
logger.warning("sso-initiate: enterprise call failed: %s", e)
raise BadGateway("sso_initiate_failed") from e
url = (reply or {}).get("url")
if not url:
raise BadGateway("sso_initiate_missing_url")
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
resp = redirect(url, code=302)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
blob = request.args.get("sso_assertion")
if not blob:
raise BadRequest("sso_assertion required")
keyset = jws.KeySet.from_shared_secret()
try:
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
raise BadRequest("invalid_sso_assertion")
user_code = (claims.get("user_code") or "").strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise Conflict("user_code_not_pending")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims["email"]):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
reason="email_belongs_to_dify_account",
)
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
iss = request.host_url.rstrip("/")
cookie_value, _ = mint_approval_grant(
keyset=keyset,
iss=iss,
subject_email=claims["email"],
subject_issuer=claims["issuer"],
user_code=user_code,
)
resp = redirect("/device?sso_verified=1", code=302)
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
return resp
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("no_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approval-context: bad cookie: %s", e)
raise Unauthorized("no_session") from e
return jsonify(
{
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}
), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("invalid_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approve-external: bad cookie: %s", e)
raise Unauthorized("invalid_session") from e
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
csrf_header = request.headers.get("X-CSRF-Token", "")
if not csrf_header or csrf_header != claims.csrf_token:
raise Forbidden("csrf_mismatch")
data = request.get_json(silent=True) or {}
body_user_code = (data.get("user_code") or "").strip().upper()
if body_user_code != claims.user_code:
raise BadRequest("user_code_mismatch")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(claims.user_code)
if found is None:
raise NotFound("user_code_not_pending")
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
raise Forbidden("email_belongs_to_dify_account")
if not consume_approval_grant_nonce(redis_client, claims.nonce):
raise Unauthorized("session_already_consumed")
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=None)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=claims.subject_email,
subject_issuer=claims.subject_issuer,
account_id=None,
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
# SSO branch of the shared PollPayload contract: account/workspace
# fields are zero-filled (`None` / `[]`) for parity with the account
# branch in `oauth_device._build_account_poll_payload`.
poll_payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
"token_id": str(mint.token_id),
}
try:
store.approve(
device_code,
subject_email=claims.subject_email,
account_id=None,
subject_issuer=claims.subject_issuer,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError) as e:
logger.exception("approve-external: state transition raced")
raise Conflict("state_lost") from e
_emit_approve_external_audit(state, claims, mint)
resp = make_response(jsonify({"status": "approved"}), 200)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@dataclass(frozen=True)
class _RejectedClaims:
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
event without reaching for the full dataclass.
"""
subject_email: str
subject_issuer: str
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
logger.warning(
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
reason,
extra={
"audit": True,
"event": "oauth.device_flow_rejected",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"reason": reason,
"client_id": state.client_id,
"device_label": state.device_label,
},
)
def _emit_approve_external_audit(state, claims, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
mint.token_id,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"token_id": str(mint.token_id),
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["apps:run"],
"expires_at": mint.expires_at.isoformat(),
},
)

View File

@ -0,0 +1,119 @@
"""
OpenAPI bearer-authed workflow reconnect event stream endpoint.
GET /apps/<app_id>/tasks/<task_id>/events
— reconnect to the SSE stream for a paused/running workflow run.
`task_id` is treated as `workflow_run_id`.
"""
from __future__ import annotations
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from libs.oauth_bearer import Scope
from models.enums import CreatorUserRole
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.response(200, "SSE event stream")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if caller_kind == "account":
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
else:
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=caller,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
else:
generator = WorkflowAppGenerator()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.OPENAPI,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -0,0 +1,78 @@
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
counterparts to the cookie-authed /console/api/workspaces endpoints.
Account bearers (dfoa_) see every tenant they're a member of. External
SSO bearers (dfoe_) have no account_id and so see an empty list — that
matches /openapi/v1/account.
"""
from __future__ import annotations
from itertools import starmap
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from models import Tenant, TenantAccountJoin
from services.account_service import TenantService
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self):
ctx = get_auth_ctx()
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self, workspace_id: str):
ctx = get_auth_ctx()
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
if row is None:
raise NotFound("workspace not found")
tenant, membership = row
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
return WorkspaceSummaryResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
)
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
return WorkspaceDetailResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
)

View File

@ -16,7 +16,7 @@ from libs.passport import PassportService
from libs.token import extract_webapp_passport from libs.token import extract_webapp_passport
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService from services.webapp_auth_service import WebAppAuthService
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if not webapp_settings: if not webapp_settings:
raise NotFound("Web app settings not found.") raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public" app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility( _validate_user_accessibility(
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
raise Unauthorized("Please re-login to access the web app.") raise Unauthorized("Please re-login to access the web app.")
app_id = AppService.get_app_id_by_code(app_code) app_id = AppService.get_app_id_by_code(app_code)
app_web_auth_enabled = ( app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public" EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
!= WebAppAccessMode.PUBLIC
) )
if app_web_auth_enabled: if app_web_auth_enabled:
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()

View File

@ -198,7 +198,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
), ),
query=query, query=query,
files=list(file_objs), files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id, user_id=user.id,
stream=streaming, stream=streaming,
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@ -167,7 +167,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
), ),
query=query, query=query,
files=list(file_objs), files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id, user_id=user.id,
stream=streaming, stream=streaming,
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@ -161,7 +161,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
), ),
query=query, query=query,
files=list(file_objs), files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id, user_id=user.id,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,

View File

@ -53,6 +53,14 @@ from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
# Maps the entry surface a workflow was invoked from to the HITL surface that
# its resume tokens must be filtered for. Surfaces not in this map fall back to
# the general priority ordering (typically CONSOLE > BACKSTAGE).
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
}
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
@ -340,11 +348,7 @@ class WorkflowResponseConverter:
form_token_by_form_id = load_form_tokens_by_form_id( form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids, human_input_form_ids,
session=session, session=session,
surface=( surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
) )
# Reconnect paths must preserve the same pause-reason contract as live streams; # Reconnect paths must preserve the same pause-reason contract as live streams;

View File

@ -731,6 +731,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
match invoke_from: match invoke_from:
case InvokeFrom.SERVICE_API: case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.OPENAPI:
created_from = WorkflowAppLogCreatedFrom.OPENAPI
case InvokeFrom.EXPLORE: case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP: case InvokeFrom.WEB_APP:

View File

@ -24,6 +24,7 @@ class UserFrom(StrEnum):
class InvokeFrom(StrEnum): class InvokeFrom(StrEnum):
SERVICE_API = "service-api" SERVICE_API = "service-api"
OPENAPI = "openapi"
WEB_APP = "web-app" WEB_APP = "web-app"
TRIGGER = "trigger" TRIGGER = "trigger"
EXPLORE = "explore" EXPLORE = "explore"
@ -42,6 +43,7 @@ class InvokeFrom(StrEnum):
InvokeFrom.EXPLORE: "explore_app", InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger", InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api", InvokeFrom.SERVICE_API: "api",
InvokeFrom.OPENAPI: "openapi",
} }
return source_mapping.get(self, "dev") return source_mapping.get(self, "dev")

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
from collections.abc import Generator, Iterable, Sequence from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Literal, cast, overload, override from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError from pydantic import ValidationError
@ -12,9 +13,9 @@ from configs import dify_config
from core.llm_generator.output_parser.structured_output import ( from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper, invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
) )
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import ( from graphon.model_runtime.entities.llm_entities import (
LLMResult, LLMResult,
@ -100,36 +101,35 @@ class _PluginStructuredOutputModelInstance:
class PluginModelRuntime(ModelRuntime): class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope. """Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
Provider discovery goes through ``PluginService`` so the plugin lifecycle
methods and provider reads share one tenant-scoped cache owner.
"""
tenant_id: str tenant_id: str
user_id: str | None user_id: str | None
client: PluginModelClient client: PluginModelClient
_plugin_service: type[PluginService] _provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock
def __init__( def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
self,
tenant_id: str,
user_id: str | None,
client: PluginModelClient,
plugin_service: type[PluginService],
) -> None:
if client is None: if client is None:
raise ValueError("client is required.") raise ValueError("client is required.")
if plugin_service is None:
raise ValueError("plugin_service is required.")
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.user_id = user_id self.user_id = user_id
self.client = client self.client = client
self._plugin_service = plugin_service self._provider_entities = None
self._provider_entities_lock = Lock()
@override @override
def fetch_model_providers(self) -> Sequence[ProviderEntity]: def fetch_model_providers(self) -> Sequence[ProviderEntity]:
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client) if self._provider_entities is not None:
return self._provider_entities
with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)
return self._provider_entities
@override @override
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
@ -628,6 +628,34 @@ class PluginModelRuntime(ModelRuntime):
text=text, text=text,
) )
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration
def _get_provider_schema(self, provider: str) -> ProviderEntity: def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers() providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None) provider_entity = next((item for item in providers if item.provider == provider), None)

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.ai_model import AIModel from graphon.model_runtime.model_providers.base.ai_model import AIModel
@ -118,7 +117,6 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
client=PluginModelClient(), client=PluginModelClient(),
plugin_service=PluginService,
) )

View File

@ -16,7 +16,6 @@ from core.plugin.entities.request import (
TriggerSubscriptionResponse, TriggerSubscriptionResponse,
) )
from core.plugin.impl.trigger import PluginTriggerClient from core.plugin.impl.trigger import PluginTriggerClient
from core.plugin.plugin_service import PluginService
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
from core.trigger.entities.entities import ( from core.trigger.entities.entities import (
EventEntity, EventEntity,
@ -31,6 +30,7 @@ from core.trigger.entities.entities import (
) )
from core.trigger.errors import TriggerProviderCredentialValidationError from core.trigger.errors import TriggerProviderCredentialValidationError
from models.provider_ids import TriggerProviderID from models.provider_ids import TriggerProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -63,7 +63,7 @@ def _get_surface_form_token(
*, *,
surface: HumanInputSurface | None, surface: HumanInputSurface | None,
) -> str | None: ) -> str | None:
if surface == HumanInputSurface.SERVICE_API: if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
for recipient_type, token in recipients: for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token: if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token return token

View File

@ -11,13 +11,15 @@ from models.human_input import RecipientType
class HumanInputSurface(StrEnum): class HumanInputSurface(StrEnum):
SERVICE_API = "service_api" SERVICE_API = "service_api"
CONSOLE = "console" CONSOLE = "console"
OPENAPI = "openapi"
# Service API is intentionally narrower than other surfaces: app-token callers # SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
# should only be able to act on end-user web forms, not internal console flows. # should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = { _ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}), HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}), HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
} }
# A single HITL form can have multiple recipient records; this shared priority # A single HITL form can have multiple recipient records; this shared priority

View File

@ -45,6 +45,7 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"), SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"), SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"), SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
) )
@ -161,6 +162,8 @@ def create_spec_app() -> Flask:
from controllers.console import bp as console_bp from controllers.console import bp as console_bp
from controllers.console import console_ns from controllers.console import console_ns
from controllers.openapi import bp as openapi_bp
from controllers.openapi import openapi_ns
from controllers.service_api import bp as service_api_bp from controllers.service_api import bp as service_api_bp
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.web import bp as web_bp from controllers.web import bp as web_bp
@ -169,8 +172,9 @@ def create_spec_app() -> Flask:
app.register_blueprint(console_bp) app.register_blueprint(console_bp)
app.register_blueprint(web_bp) app.register_blueprint(web_bp)
app.register_blueprint(service_api_bp) app.register_blueprint(service_api_bp)
app.register_blueprint(openapi_bp)
for namespace in (console_ns, web_ns, service_api_ns): for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
for api in namespace.apis: for api in namespace.apis:
_materialize_inline_model_definitions(api) _materialize_inline_model_definitions(api)
@ -201,6 +205,13 @@ def _registered_models(namespace: str) -> dict[str, object]:
for api in service_api_ns.apis: for api in service_api_ns.apis:
models.update(api.models) models.update(api.models)
return models return models
if namespace == "openapi":
from controllers.openapi import openapi_ns
models = dict(openapi_ns.models)
for api in openapi_ns.apis:
models.update(api.models)
return models
raise ValueError(f"unknown Swagger namespace: {namespace}") raise ValueError(f"unknown Swagger namespace: {namespace}")

View File

@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN) FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE) EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id") EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
OPENAPI_MAX_AGE_SECONDS: int = 600
def _apply_cors_once(bp, /, **cors_kwargs): def _apply_cors_once(bp, /, **cors_kwargs):
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
from controllers.files import bp as files_bp from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp from controllers.inner_api import bp as inner_api_bp
from controllers.mcp import bp as mcp_bp from controllers.mcp import bp as mcp_bp
from controllers.openapi import bp as openapi_bp
from controllers.service_api import bp as service_api_bp from controllers.service_api import bp as service_api_bp
from controllers.trigger import bp as trigger_bp from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp from controllers.web import bp as web_bp
@ -41,6 +44,23 @@ def init_app(app: DifyApp):
) )
app.register_blueprint(service_api_bp) app.register_blueprint(service_api_bp)
if dify_config.OPENAPI_ENABLED:
# User-scoped programmatic API. Default empty allowlist = same-origin
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
# integrations. supports_credentials so cookie-authed approve/deny
# work; cross-origin OPTIONS without an allowed origin will fail
# the same as on the console blueprint.
_apply_cors_once(
openapi_bp,
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(OPENAPI_HEADERS),
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
expose_headers=list(EXPOSED_HEADERS),
max_age=OPENAPI_MAX_AGE_SECONDS,
)
app.register_blueprint(openapi_bp)
_apply_cors_once( _apply_cors_once(
web_bp, web_bp,
resources={ resources={

View File

@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task", "task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"), "schedule": crontab(minute="0", hour="0"),
} }
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
imports.append("schedule.clean_oauth_access_tokens_task")
beat_schedule["clean_oauth_access_tokens_task"] = {
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task") imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = { beat_schedule["workflow_schedule_task"] = {

View File

@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp from dify_app import DifyApp
from extensions.ext_database import db from extensions.ext_database import db
from libs.passport import PassportService from libs.passport import PassportService
from libs.token import extract_access_token, extract_webapp_passport from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser from models.model import AppMCPServer, EndUser
from services.account_service import AccountService from services.account_service import AccountService
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
logged_in_account = AccountService.load_logged_in_account(account_id=user_id) logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account return logged_in_account
elif request.blueprint == "openapi":
# Account-branch device-flow approval routes (approve / deny /
# approval-context) sit under @login_required and authenticate via
# the console session cookie. Cookie-only on purpose — bearer
# tokens (dfoa_/dfoe_) live on the Authorization header and are
# validated by AppPipeline, not flask-login.
cookie_token = extract_console_cookie_token(request)
if not cookie_token:
return None
try:
decoded = PassportService().verify(cookie_token)
except Exception:
return None
user_id = decoded.get("user_id")
source = decoded.get("token_source")
if source or not user_id:
return None
return AccountService.load_logged_in_account(account_id=user_id)
elif request.blueprint == "web": elif request.blueprint == "web":
app_code = request.headers.get(HEADER_NAME_APP_CODE) app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None webapp_token = extract_webapp_passport(app_code, request) if app_code else None

View File

@ -0,0 +1,23 @@
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from __future__ import annotations
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import build_and_bind
def is_enabled() -> bool:
return dify_config.ENABLE_OAUTH_BEARER
def init_app(app: DifyApp) -> None:
# scoped_session isn't a context manager; request teardown closes it.
def session_factory():
return db.session
build_and_bind(session_factory=session_factory, redis_client=redis_client)

View File

@ -0,0 +1,196 @@
"""Device-flow security primitives: enterprise_only gate, approval-grant
cookie mint/verify/consume, and anti-framing headers.
"""
from __future__ import annotations
import logging
import secrets
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import wraps
from flask import Blueprint
from werkzeug.exceptions import NotFound
from libs import jws
from libs.token import is_secure
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# ============================================================================
# enterprise_only decorator
# ============================================================================
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
responses don't consume the bucket.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status not in _EE_ENABLED_STATUSES:
raise NotFound()
return view(*args, **kwargs)
return decorated
# ============================================================================
# approval_grant cookie
# ============================================================================
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
subject_email: str
subject_issuer: str
user_code: str
nonce: str
csrf_token: str
expires_at: datetime
def mint_approval_grant(
*,
keyset: jws.KeySet,
iss: str,
subject_email: str,
subject_issuer: str,
user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
source of truth for Path/HttpOnly/Secure/SameSite.
"""
now = datetime.now(UTC)
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
nonce = _random_opaque()
csrf_token = _random_opaque()
payload = {
"iss": iss,
"subject_email": subject_email,
"subject_issuer": subject_issuer,
"user_code": user_code,
"nonce": nonce,
"csrf_token": csrf_token,
}
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
return token, ApprovalGrantClaims(
subject_email=subject_email,
subject_issuer=subject_issuer,
user_code=user_code,
nonce=nonce,
csrf_token=csrf_token,
expires_at=exp,
)
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
"""Sig + aud + exp only — nonce consumption is the caller's job."""
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
return ApprovalGrantClaims(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
user_code=data["user_code"],
nonce=data["nonce"],
csrf_token=data["csrf_token"],
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
)
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
def approval_grant_cookie_kwargs(value: str) -> dict:
"""``secure`` follows is_secure() so HTTP-only deployments don't
silently drop the cookie.
"""
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": value,
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def approval_grant_cleared_cookie_kwargs() -> dict:
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": "",
"max_age": 0,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def _random_opaque() -> str:
return secrets.token_urlsafe(16)
# ============================================================================
# Anti-framing headers
# ============================================================================
_ANTI_FRAMING_HEADERS = {
"X-Frame-Options": "DENY",
"Content-Security-Policy": "frame-ancestors 'none'",
}
def attach_anti_framing(bp: Blueprint) -> None:
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
@bp.after_request
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
for name, value in _ANTI_FRAMING_HEADERS.items():
response.headers.setdefault(name, value)
return response

View File

@ -76,6 +76,7 @@ def register_external_error_handlers(api: Api):
def handle_value_error(e: ValueError): def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e) got_request_exception.send(current_app, exception=e)
current_app.logger.exception("value_error in request handler")
status_code = 400 status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code} data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code return data, status_code

View File

@ -595,3 +595,18 @@ class RateLimiter:
self._redis_client.zadd(key, {member: current_time}) self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2) self._redis_client.expire(key, self.time_window * 2)
def seconds_until_available(self, email: str) -> int:
"""Seconds until the oldest in-window entry expires, freeing a slot.
Defensive floor of 1 second. Caller should only invoke this after
is_rate_limited() returned True.
"""
key = self._get_key(email)
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
if not oldest:
return 1
_member, score = oldest[0]
free_at = int(score) + self.time_window
remaining = free_at - int(time.time())
return max(remaining, 1)

108
api/libs/jws.py Normal file
View File

@ -0,0 +1,108 @@
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
state envelope, external subject assertion, and approval-grant cookie —
all three share one key-set so api ↔ enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
ACTIVE_KID_V1 = "dify-shared-v1"
class KeySetError(Exception):
pass
class KeySet:
"""``from_entries`` reserves multi-kid construction for rotation slots."""
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
if active_kid not in entries:
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
if not entries[active_kid]:
raise KeySetError(f"active kid {active_kid!r} has empty secret")
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
self._active_kid = active_kid
@classmethod
def from_shared_secret(cls) -> KeySet:
secret = dify_config.SECRET_KEY
if not secret:
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
@classmethod
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
return cls(entries, active_kid)
@property
def active_kid(self) -> str:
return self._active_kid
def lookup(self, kid: str) -> bytes | None:
return self._entries.get(kid)
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
"""``iat`` + ``exp`` are injected here; callers must not set them."""
if "aud" in payload or "iat" in payload or "exp" in payload:
raise ValueError("reserved claim present in payload (aud/iat/exp)")
if ttl_seconds <= 0:
raise ValueError("ttl_seconds must be positive")
kid = keyset.active_kid
secret = keyset.lookup(kid)
if secret is None:
raise KeySetError(f"active kid {kid!r} lookup miss")
iat = datetime.now(UTC)
exp = iat + timedelta(seconds=ttl_seconds)
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
return jwt.encode(
claims,
secret,
algorithm="HS256",
headers={"kid": kid, "typ": "JWT"},
)
class VerifyError(Exception):
pass
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
"""Unknown kid is rejected — never fall back to the active kid, since
a past kid value would otherwise be forgeable by anyone who saw it.
"""
try:
header = jwt.get_unverified_header(token)
except jwt.PyJWTError as e:
raise VerifyError(f"decode header: {e}") from e
kid = header.get("kid")
if not kid:
raise VerifyError("no kid in header")
secret = keyset.lookup(kid)
if secret is None:
raise VerifyError(f"unknown kid {kid!r}")
try:
return jwt.decode(
token,
secret,
algorithms=["HS256"],
audience=expected_aud,
)
except jwt.ExpiredSignatureError as e:
raise VerifyError("token expired") from e
except jwt.InvalidAudienceError as e:
raise VerifyError("aud mismatch") from e
except jwt.PyJWTError as e:
raise VerifyError(f"decode: {e}") from e

685
api/libs/oauth_bearer.py Normal file
View File

@ -0,0 +1,685 @@
"""OAuth bearer primitives.
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from collections.abc import Callable, Iterable
from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import request
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.rate_limit import enforce_bearer_rate_limit
from models import Account, OAuthAccessToken, TenantAccountJoin
logger = logging.getLogger(__name__)
# ============================================================================
# Contract — types, enums, protocols
# ============================================================================
class SubjectType(StrEnum):
ACCOUNT = "account"
EXTERNAL_SSO = "external_sso"
class Scope(StrEnum):
"""Catalog of bearer scopes recognised by the openapi surface.
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
"""
FULL = "full"
APPS_READ = "apps:read"
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
APPS_RUN = "apps:run"
class Accepts(StrEnum):
"""Subject types a route is willing to accept as caller."""
USER_ACCOUNT = "user_account"
USER_EXT_SSO = "user_ext_sso"
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
@dataclass(frozen=True, slots=True)
class AuthContext:
"""Per-request identity published via :data:`_auth_ctx_var`
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
``subject_type`` / ``source`` come from the TokenKind, not the DB —
corrupt rows can't elevate scope.
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
authenticate time. Per-request mutations write through to Redis via
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
"""
subject_type: SubjectType
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
scopes: frozenset[Scope]
token_id: uuid.UUID
source: str
expires_at: datetime | None
token_hash: str
verified_tenants: dict[str, bool] = field(default_factory=dict)
_auth_ctx_var: ContextVar[AuthContext] = ContextVar("openapi_auth_ctx")
def set_auth_ctx(ctx: AuthContext) -> Token[AuthContext]:
return _auth_ctx_var.set(ctx)
def reset_auth_ctx(token: Token[AuthContext]) -> None:
_auth_ctx_var.reset(token)
def get_auth_ctx() -> AuthContext:
return _auth_ctx_var.get()
def try_get_auth_ctx() -> AuthContext | None:
return _auth_ctx_var.get(None)
@dataclass(frozen=True, slots=True)
class ResolvedRow:
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
token_id: uuid.UUID
expires_at: datetime | None
verified_tenants: dict[str, bool] = field(default_factory=dict)
def to_cache(self) -> dict:
return {
"subject_email": self.subject_email,
"subject_issuer": self.subject_issuer,
"account_id": str(self.account_id) if self.account_id else None,
"client_id": self.client_id,
"token_id": str(self.token_id),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"verified_tenants": dict(self.verified_tenants),
}
@classmethod
def from_cache(cls, data: dict) -> ResolvedRow:
return cls(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
client_id=data.get("client_id"),
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
)
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
on all live deployments (60s TTL → safe to drop one release after rollout).
"""
if not isinstance(raw, dict):
return {}
out: dict[str, bool] = {}
for k, v in raw.items():
if isinstance(v, bool):
out[k] = v
elif v == "ok":
out[k] = True
elif v == "denied":
out[k] = False
return out
class Resolver(Protocol):
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
...
@dataclass(frozen=True, slots=True)
class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[Scope]
source: str
resolver: Resolver
def matches(self, token: str) -> bool:
return token.startswith(self.prefix)
@dataclass(frozen=True, slots=True)
class MintProfile:
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
Consumers:
- ``build_registry`` reads scopes here so the resolve-time TokenKind
cannot drift from the mint-time intent.
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
the (subject_type, prefix, scopes) triple a caller intends to mint
against this table — a caller that assembles its own scope set
from a non-canonical source will fail closed at approve time.
"""
subject_type: SubjectType
prefix: str
scopes: frozenset[Scope]
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
SubjectType.ACCOUNT: MintProfile(
subject_type=SubjectType.ACCOUNT,
prefix="dfoa_",
scopes=frozenset({Scope.FULL}),
),
SubjectType.EXTERNAL_SSO: MintProfile(
subject_type=SubjectType.EXTERNAL_SSO,
prefix="dfoe_",
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
),
}
class InvalidBearerError(Exception):
"""Token missing, unknown prefix, or no live row."""
class TokenExpiredError(Exception):
"""Hard-expire bookkeeping is the resolver's job before raising."""
# ============================================================================
# Registry
# ============================================================================
class TokenKindRegistry:
def __init__(self, kinds: Iterable[TokenKind]) -> None:
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
prefixes = [k.prefix for k in self._kinds]
if len(set(prefixes)) != len(prefixes):
raise ValueError(f"duplicate prefix in registry: {prefixes}")
def find(self, token: str) -> TokenKind | None:
for k in self._kinds:
if k.matches(token):
return k
return None
def kinds(self) -> tuple[TokenKind, ...]:
return self._kinds
# ============================================================================
# Authenticator
# ============================================================================
def sha256_hex(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
class BearerAuthenticator:
def __init__(self, registry: TokenKindRegistry) -> None:
self._registry = registry
@property
def registry(self) -> TokenKindRegistry:
return self._registry
def authenticate(self, token: str) -> AuthContext:
"""Identity + per-token rate limit (single source).
Both the openapi pipeline (`BearerCheck`) and the decorator
(`validate_bearer`) call this — rate-limit fires exactly once per
request regardless of which path hosts the route.
"""
kind = self._registry.find(token)
if kind is None:
raise InvalidBearerError("unknown token prefix")
token_hash = sha256_hex(token)
row = kind.resolver.resolve(token_hash)
if row is None:
raise InvalidBearerError("token unknown or revoked")
enforce_bearer_rate_limit(token_hash)
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=row.account_id,
client_id=row.client_id,
scopes=kind.scopes,
token_id=row.token_id,
source=kind.source,
expires_at=row.expires_at,
token_hash=token_hash,
verified_tenants=dict(row.verified_tenants),
)
# ============================================================================
# OAuth access token resolver (PAT resolver would be a sibling class)
# ============================================================================
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
ScopeVariant = Literal["account", "external_sso"]
class OAuthAccessTokenResolver:
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
sharing DB + cache plumbing.
"""
def __init__(
self,
session_factory,
redis_client,
positive_ttl: int = POSITIVE_TTL_SECONDS,
negative_ttl: int = NEGATIVE_TTL_SECONDS,
) -> None:
self.session_factory = session_factory
self._redis = redis_client
self._positive_ttl = positive_ttl
self._negative_ttl = negative_ttl
def for_account(self) -> Resolver:
return _VariantResolver(self, variant="account")
def for_external_sso(self) -> Resolver:
return _VariantResolver(self, variant="external_sso")
def _cache_key(self, token_hash: str) -> str:
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
raw = self._redis.get(self._cache_key(token_hash))
if raw is None:
return None
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return "invalid"
try:
return ResolvedRow.from_cache(json.loads(text))
except (ValueError, KeyError):
logger.warning("auth:token cache entry malformed; treating as miss")
return None
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
self._redis.setex(
self._cache_key(token_hash),
self._positive_ttl,
json.dumps(row.to_cache()),
)
def cache_set_negative(self, token_hash: str) -> None:
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
"""Atomic CAS — only the worker that flips revoked_at emits audit;
replays are idempotent.
"""
stmt = (
update(OAuthAccessToken)
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
.values(revoked_at=datetime.now(UTC), token_hash=None)
)
result = session.execute(stmt)
session.commit()
if result.rowcount == 1: # type: ignore
logger.warning(
"audit: %s token_id=%s",
AUDIT_OAUTH_EXPIRED,
row_id,
extra={"audit": True, "token_id": str(row_id)},
)
self._redis.delete(self._cache_key(token_hash))
self.cache_set_negative(token_hash)
class _VariantResolver:
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
self._parent = parent
self._variant = variant
def resolve(self, token_hash: str) -> ResolvedRow | None:
cached = self._parent.cache_get(token_hash)
if cached == "invalid":
return None
if cached is not None and not isinstance(cached, str):
if not self._matches_variant(cached):
return None
return cached
# Flask-SQLAlchemy's scoped_session is request-bound and not a
# context manager; use it directly.
session = self._parent.session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
self._parent.cache_set_negative(token_hash)
return None
now = datetime.now(UTC)
if row.expires_at is not None and row.expires_at <= now:
self._parent.hard_expire(session, row.id, token_hash)
return None
if not self._matches_variant_model(row):
logger.error(
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
row.id,
row.prefix,
)
return None
resolved = ResolvedRow(
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
client_id=row.client_id,
token_id=uuid.UUID(str(row.id)),
expires_at=row.expires_at,
)
self._parent.cache_set_positive(token_hash, resolved)
return resolved
def _matches_variant(self, row: ResolvedRow) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account
return not has_account
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account and row.prefix == "dfoa_"
return (not has_account) and row.prefix == "dfoe_"
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
return (
session.query(OAuthAccessToken)
.filter(
OAuthAccessToken.token_hash == token_hash,
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
# ============================================================================
# Layer 0 — workspace membership cache + helper
# ============================================================================
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
`auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
rebuilds via authenticate() and re-runs Layer 0.
"""
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
raw = redis_client.get(cache_key)
if raw is None:
return
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return
try:
data = json.loads(text)
except (ValueError, KeyError):
return
ttl = redis_client.ttl(cache_key)
if ttl <= 0:
return
data.setdefault("verified_tenants", {})[tenant_id] = verdict
redis_client.setex(cache_key, ttl, json.dumps(data))
def check_workspace_membership(
*,
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
cached_verdicts: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
inline helper (`require_workspace_member`). Caller is responsible for
short-circuiting on EE / SSO subjects before invoking — this function
runs the membership + active-status checks unconditionally.
"""
cached = cached_verdicts.get(tenant_id)
if cached is True:
return
if cached is False:
raise Forbidden("workspace_membership_revoked")
join = db.session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.account_id == account_id,
TenantAccountJoin.tenant_id == tenant_id,
)
).scalar_one_or_none()
if join is None:
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
if status != "active":
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
record_layer0_verdict(token_hash, tenant_id, True)
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
(no `tenant_account_joins` row by definition).
"""
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
return
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=tenant_id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.verified_tenants,
)
# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================
_authenticator: BearerAuthenticator | None = None
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
global _authenticator
_authenticator = authenticator
def get_authenticator() -> BearerAuthenticator:
if _authenticator is None:
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
return _authenticator
def extract_bearer(req) -> str | None:
"""Pull the bearer token out of an HTTP request's Authorization header.
Used by both attachment paths (the ``validate_bearer`` decorator and the
openapi ``Pipeline.guard``) so the parsing rule lives in one place. Pipeline
callers extract once at the boundary and pass the token through ``Context``
so steps stay independent of the request object.
"""
header = req.headers.get("Authorization", "")
scheme, _, value = header.partition(" ")
if scheme.lower() != "bearer" or not value:
return None
return value.strip()
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
"""Opt-in: omitting it leaves the route unauthenticated.
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
and are rejected here as the wrong auth scheme for this surface.
"""
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
@wraps(fn)
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
token = extract_bearer(request)
if token is None:
raise Unauthorized("missing bearer token")
if _authenticator is None:
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
try:
ctx = get_authenticator().authenticate(token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
raise Forbidden("token subject type not accepted here")
# Try/finally pairing — the WSGI worker thread is reused
# across requests, so a leaked ContextVar would publish the
# previous caller's identity to the next request.
reset_token = set_auth_ctx(ctx)
try:
return fn(*args, **kwargs)
finally:
reset_auth_ctx(reset_token)
return inner
return wrap
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
without the authenticator, so fail fast instead of approving silently.
"""
@wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ENABLE_OAUTH_BEARER:
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
return fn(*args, **kwargs)
return inner
def require_scope(scope: Scope) -> Callable:
"""Route-level scope gate — must run AFTER validate_bearer so that
the auth ContextVar is set. Raises ``Forbidden('insufficient_scope: <scope>')``
when the bearer lacks both the requested scope and ``Scope.FULL``.
"""
def wrap(fn: Callable) -> Callable:
@wraps(fn)
def inner(*args, **kwargs):
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
)
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
raise Forbidden(f"insufficient_scope: {scope}")
return fn(*args, **kwargs)
return inner
return wrap
# ============================================================================
# Wiring — called once from the app factory
# ============================================================================
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
return TokenKindRegistry(
[
TokenKind(
prefix=account.prefix,
subject_type=account.subject_type,
scopes=account.scopes,
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix=external.prefix,
subject_type=external.subject_type,
scopes=external.scopes,
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
]
)
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
registry = build_registry(session_factory, redis_client)
auth = BearerAuthenticator(registry)
bind_authenticator(auth)
return auth

147
api/libs/rate_limit.py Normal file
View File

@ -0,0 +1,147 @@
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
window Redis ZSET). Apply after auth decorators so account/email/token-id
scopes can read the openapi auth ContextVar (see
:func:`libs.oauth_bearer.try_get_auth_ctx`). Use :func:`enforce` when the
bucket key is computed in-handler. RFC-8628 ``slow_down`` is inline — its
response shape isn't generic 429.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import jsonify, make_response, request, session
from werkzeug.exceptions import TooManyRequests
from configs import dify_config
from libs.helper import RateLimiter, extract_remote_ip
class RateLimitScope(StrEnum):
IP = "ip"
SESSION = "session"
ACCOUNT = "account"
SUBJECT_EMAIL = "subject_email"
TOKEN_ID = "token_id"
@dataclass(frozen=True, slots=True)
class RateLimit:
limit: int
window: timedelta
scopes: tuple[RateLimitScope, ...]
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_BEARER_PER_TOKEN = RateLimit(
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
window=timedelta(minutes=1),
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
)
def _one_key(scope: RateLimitScope) -> str:
match scope:
case RateLimitScope.IP:
return f"ip:{extract_remote_ip(request) or 'unknown'}"
case RateLimitScope.SESSION:
return f"session:{session.get('_id', 'anon')}"
case RateLimitScope.ACCOUNT:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.account_id:
return f"account:{ctx.account_id}"
return "account:anon"
case RateLimitScope.SUBJECT_EMAIL:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.subject_email:
return f"subject:{ctx.subject_email}"
return "subject:anon"
case RateLimitScope.TOKEN_ID:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.token_id:
return f"token:{ctx.token_id}"
return "token:anon"
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
return "|".join(_one_key(s) for s in scopes)
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
return "rl:" + "+".join(s.value for s in scopes)
def _build_limiter(spec: RateLimit) -> RateLimiter:
return RateLimiter(
prefix=_limiter_prefix(spec.scopes),
max_attempts=spec.limit,
time_window=int(spec.window.total_seconds()),
)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Apply after auth decorators that the scopes read from."""
limiter = _build_limiter(spec)
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
key = _composite_key(spec.scopes)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
return fn(*args, **kwargs)
return inner
return wrap
def enforce(spec: RateLimit, *, key: str) -> None:
"""Imperative form — caller composes the bucket key to match scope
semantics (the key is opaque here).
"""
limiter = _build_limiter(spec)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
def enforce_bearer_rate_limit(token_hash: str) -> None:
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
Bucket key = ``token:<sha256_hex>`` so the same token shares one
bucket across api replicas (Redis-backed sliding window).
"""
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
key = f"token:{token_hash}"
if limiter.is_rate_limited(key):
retry_after = limiter.seconds_until_available(key)
response = make_response(
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
429,
)
response.headers["Retry-After"] = str(retry_after)
raise TooManyRequests(response=response)
limiter.increment_rate_limit(key)

View File

@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN)) return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None: def extract_console_cookie_token(request: Request) -> str | None:
def _try_extract_from_cookie(request: Request) -> str | None: """Cookie-only console session token. Used by /openapi/v1/oauth/device/*
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN)) approval routes, which must not fall through to the Authorization header
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_access_token(request: Request) -> str | None:
return extract_console_cookie_token(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None: def extract_webapp_access_token(request: Request) -> str | None:

View File

@ -0,0 +1,128 @@
"""add oauth_access_tokens table
Revision ID: d4a5e1f3c9b7
Revises: f8b6b7e9c421
Create Date: 2026-05-22 17:00:00.000000
Table stores user-level OAuth bearer tokens minted via the device-flow grant
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
table not added in this migration.
Cross-dialect notes:
- UUID columns use ``models.types.StringUUID`` (UUID on PG, CHAR(36) on
MySQL). The application generates ids via ``libs.uuid_utils.uuidv7``;
on PG we additionally set a ``server_default`` so direct SQL inserts
remain valid.
- Indexed text columns are bounded ``VARCHAR(255)`` because MySQL cannot
index ``TEXT`` without an explicit prefix length.
- ``postgresql_where=`` is silently dropped by SQLAlchemy on MySQL, so the
partial-index filters degrade to plain indexes — semantically a
superset, still correct for lookup. The composite unique index on
``(subject_email, subject_issuer, client_id, device_label)`` enforces
uniqueness across both dialects (NULLs are distinct in both, matching
the rotate-in-place contract documented on ``OAuthAccessToken``).
"""
import sqlalchemy as sa
from alembic import op
import models
# revision identifiers, used by Alembic.
revision = "d4a5e1f3c9b7"
down_revision = "f8b6b7e9c421"
branch_labels = None
depends_on = None
def _is_pg() -> bool:
return op.get_bind().dialect.name == "postgresql"
def upgrade():
id_kwargs: dict = {"nullable": False, "primary_key": True}
if _is_pg():
# Match the convention established by 2026_05_19_1000 (uuidv7()).
id_kwargs["server_default"] = sa.text("uuidv7()")
op.create_table(
"oauth_access_tokens",
sa.Column("id", models.types.StringUUID(), **id_kwargs),
sa.Column("subject_email", sa.String(length=255), nullable=False),
sa.Column("subject_issuer", sa.String(length=255), nullable=True),
sa.Column("account_id", models.types.StringUUID(), nullable=True),
sa.Column("client_id", sa.String(length=64), nullable=False),
sa.Column("device_label", sa.String(length=255), nullable=False),
sa.Column("prefix", sa.String(length=8), nullable=False),
sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.current_timestamp(),
nullable=False,
),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["account_id"],
["accounts.id"],
name="fk_oauth_access_tokens_account_id",
ondelete="SET NULL",
),
)
# Partial-index WHERE clauses are PG-only (SQLAlchemy drops the kwarg
# on MySQL → plain index, which is still correct for lookup).
op.create_index(
"idx_oauth_subject_email",
"oauth_access_tokens",
["subject_email"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
op.create_index(
"idx_oauth_account",
"oauth_access_tokens",
["account_id"],
postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
)
op.create_index(
"idx_oauth_client",
"oauth_access_tokens",
["subject_email", "client_id"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
op.create_index(
"idx_oauth_token_hash",
"oauth_access_tokens",
["token_hash"],
postgresql_where=sa.text("revoked_at IS NULL"),
)
# Rotate-in-place keyed on (subject, client, device). The app always
# writes a non-NULL subject_issuer (account flow uses a sentinel,
# external-SSO uses the verified IdP issuer); without that guarantee
# the composite key would never collide because both PG and MySQL
# treat NULLs as distinct in unique indices.
#
# ``mysql_length`` truncates each text column to 191 chars in the index
# — utf8mb4 makes the per-row index entry (191+191+64+191)*4 = 2548
# bytes, comfortably under InnoDB's 3072-byte index limit. Collisions
# on the 191-char prefix are vanishingly unlikely for real emails /
# OIDC issuers / device labels, and the app re-checks the full-row
# invariant before issuing a rotation.
op.create_index(
"uq_oauth_active_per_device",
"oauth_access_tokens",
["subject_email", "subject_issuer", "client_id", "device_label"],
unique=True,
postgresql_where=sa.text("revoked_at IS NULL"),
mysql_length={"subject_email": 191, "subject_issuer": 191, "device_label": 191},
)
def downgrade():
op.drop_index("uq_oauth_active_per_device", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_token_hash", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_client", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_account", table_name="oauth_access_tokens")
op.drop_index("idx_oauth_subject_email", table_name="oauth_access_tokens")
op.drop_table("oauth_access_tokens")

View File

@ -86,7 +86,7 @@ from .model import (
TrialApp, TrialApp,
UploadFile, UploadFile,
) )
from .oauth import DatasourceOauthParamConfig, DatasourceProvider from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
from .provider import ( from .provider import (
LoadBalancingModelConfig, LoadBalancingModelConfig,
Provider, Provider,
@ -199,6 +199,7 @@ __all__ = [
"MessageChain", "MessageChain",
"MessageFeedback", "MessageFeedback",
"MessageFile", "MessageFile",
"OAuthAccessToken",
"OperationLog", "OperationLog",
"PinnedConversation", "PinnedConversation",
"Provider", "Provider",

View File

@ -185,6 +185,7 @@ class InvokeFrom(StrEnum):
DEBUGGER = "debugger" DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published" PUBLISHED_PIPELINE = "published"
VALIDATION = "validation" VALIDATION = "validation"
OPENAPI = "openapi"
@classmethod @classmethod
def value_of(cls, value: str) -> "InvokeFrom": def value_of(cls, value: str) -> "InvokeFrom":
@ -197,6 +198,7 @@ class InvokeFrom(StrEnum):
InvokeFrom.EXPLORE: "explore_app", InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger", InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api", InvokeFrom.SERVICE_API: "api",
InvokeFrom.OPENAPI: "openapi",
} }
return source_mapping.get(self, "dev") return source_mapping.get(self, "dev")

View File

@ -492,8 +492,8 @@ class App(Base):
@property @property
def deleted_tools(self) -> list[DeletedToolInfo]: def deleted_tools(self) -> list[DeletedToolInfo]:
from core.plugin.plugin_service import PluginService
from core.tools.tool_manager import ToolManager, ToolProviderType from core.tools.tool_manager import ToolManager, ToolProviderType
from services.plugin.plugin_service import PluginService
# get agent mode tools # get agent mode tools
app_model_config = self.app_model_config app_model_config = self.app_model_config

View File

@ -84,3 +84,39 @@ class DatasourceOauthTenantParamConfig(TypeBase):
onupdate=func.current_timestamp(), onupdate=func.current_timestamp(),
init=False, init=False,
) )
class OAuthAccessToken(TypeBase):
"""Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
subject_issuer = "dify:account" sentinel); account_id NULL +
subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
subject_issuer is non-NULL for all rows the app writes — Postgres
treats NULLs as distinct in unique indices, so the partial unique
index on (subject_email, subject_issuer, client_id, device_label)
WHERE revoked_at IS NULL would otherwise fail to rotate in place.
"""
__tablename__ = "oauth_access_tokens"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
)
# Indexed text columns are bounded VARCHARs so the schema is portable
# across PostgreSQL and MySQL (MySQL cannot index TEXT without a prefix
# length). 255 chars accommodates RFC-compliant emails and typical
# OIDC issuer URLs / device labels.
subject_email: Mapped[str] = mapped_column(sa.String(255), nullable=False)
client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
device_label: Mapped[str] = mapped_column(sa.String(255), nullable=False)
prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
subject_issuer: Mapped[str | None] = mapped_column(sa.String(255), nullable=True, default=None)
account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
)

View File

@ -1209,6 +1209,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
SERVICE_API = "service-api" SERVICE_API = "service-api"
WEB_APP = "web-app" WEB_APP = "web-app"
INSTALLED_APP = "installed-app" INSTALLED_APP = "installed-app"
OPENAPI = "openapi"
@classmethod @classmethod
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom": def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":

View File

@ -0,0 +1,656 @@
# OpenAPI
User-scoped programmatic API (bearer auth)
## Version: 1.0
### Security
**Bearer**
| apiKey | *API Key* |
| ------ | --------- |
| Description | Type: Bearer {your-api-key} |
| In | header |
| Name | Authorization |
---
## openapi
User-scoped operations
### /_health
#### GET
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Success |
### /_version
#### GET
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Server version | [ServerVersionResponse](#serverversionresponse) |
### /account
#### GET
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Account info | [AccountResponse](#accountresponse) |
### /account/sessions
#### GET
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Session list | [SessionListResponse](#sessionlistresponse) |
### /account/sessions/self
#### DELETE
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Session revoked | [RevokeResponse](#revokeresponse) |
### /account/sessions/{session_id}
#### DELETE
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| session_id | path | | Yes | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Session revoked | [RevokeResponse](#revokeresponse) |
### /apps
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| limit | query | | No | integer |
| mode | query | | No | string |
| name | query | | No | string |
| page | query | | No | integer |
| tag | query | | No | string |
| workspace_id | query | | Yes | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | App list | [AppListResponse](#applistresponse) |
### /apps/{app_id}/describe
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| fields | query | | No | [ string ] |
| workspace_id | query | | No | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | App description | [AppDescribeResponse](#appdescriberesponse) |
### /apps/{app_id}/files/upload
#### POST
##### Description
Upload a file to use as an input variable when running the app
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 201 | File uploaded successfully | [FileResponse](#fileresponse) |
| 400 | Bad request — no file or filename missing | |
| 401 | Unauthorized — invalid or expired bearer token | |
| 413 | File too large | |
| 415 | Unsupported file type or blocked extension | |
### /apps/{app_id}/form/human_input/{form_token}
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| form_token | path | | Yes | string |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Form definition |
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| form_token | path | | Yes | string |
| payload | body | | Yes | [HumanInputFormSubmitPayload](#humaninputformsubmitpayload) |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Form submitted |
### /apps/{app_id}/run
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| payload | body | | Yes | [AppRunRequest](#apprunrequest) |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Run result (SSE stream) |
### /apps/{app_id}/tasks/{task_id}/events
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| task_id | path | | Yes | string |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | SSE event stream |
### /apps/{app_id}/tasks/{task_id}/stop
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| app_id | path | | Yes | string |
| task_id | path | | Yes | string |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Task stopped |
### /oauth/device/approve
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| payload | body | | Yes | [DeviceMutateRequest](#devicemutaterequest) |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Approved | [DeviceMutateResponse](#devicemutateresponse) |
### /oauth/device/code
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| payload | body | | Yes | [DeviceCodeRequest](#devicecoderequest) |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Device code created | [DeviceCodeResponse](#devicecoderesponse) |
### /oauth/device/deny
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| payload | body | | Yes | [DeviceMutateRequest](#devicemutaterequest) |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Denied | [DeviceMutateResponse](#devicemutateresponse) |
### /oauth/device/lookup
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| user_code | query | | Yes | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Device lookup result | [DeviceLookupResponse](#devicelookupresponse) |
### /oauth/device/token
#### POST
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| payload | body | | Yes | [DevicePollRequest](#devicepollrequest) |
##### Responses
| Code | Description |
| ---- | ----------- |
| 200 | Success |
### /permitted-external-apps
#### GET
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Permitted external apps list | [PermittedExternalAppsListResponse](#permittedexternalappslistresponse) |
### /workspaces
#### GET
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Workspace list | [WorkspaceListResponse](#workspacelistresponse) |
### /workspaces/{workspace_id}
#### GET
##### Parameters
| Name | Located in | Description | Required | Schema |
| ---- | ---------- | ----------- | -------- | ------ |
| workspace_id | path | | Yes | string |
##### Responses
| Code | Description | Schema |
| ---- | ----------- | ------ |
| 200 | Workspace detail | [WorkspaceDetailResponse](#workspacedetailresponse) |
---
### Models
#### AccountPayload
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| email | string | | Yes |
| id | string | | Yes |
| name | string | | Yes |
#### AccountResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| account | [AccountPayload](#accountpayload) | | No |
| default_workspace_id | string | | No |
| subject_email | string | | No |
| subject_issuer | string | | No |
| subject_type | string | | Yes |
| workspaces | [ [WorkspacePayload](#workspacepayload) ] | | No |
#### AppDescribeInfo
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| author | string | | No |
| description | string | | No |
| id | string | | Yes |
| is_agent | boolean | | No |
| mode | string | | Yes |
| name | string | | Yes |
| service_api_enabled | boolean | | Yes |
| tags | [ [TagItem](#tagitem) ] | | No |
| updated_at | string | | No |
#### AppDescribeQuery
`?fields=` allow-list for GET /apps/<id>/describe.
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| fields | [ string ] | | No |
| workspace_id | string | | No |
#### AppDescribeResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| info | [AppDescribeInfo](#appdescribeinfo) | | No |
| input_schema | object | | No |
| parameters | object | | No |
#### AppInfoResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| author | string | | No |
| description | string | | No |
| id | string | | Yes |
| mode | string | | Yes |
| name | string | | Yes |
| tags | [ [TagItem](#tagitem) ] | | No |
#### AppListQuery
mode is a closed enum.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| limit | integer | | No |
| mode | [AppMode](#appmode) | | No |
| name | string | | No |
| page | integer | | No |
| tag | string | | No |
| workspace_id | string | | Yes |
#### AppListResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| data | [ [AppListRow](#applistrow) ] | | Yes |
| has_more | boolean | | Yes |
| limit | integer | | Yes |
| page | integer | | Yes |
| total | integer | | Yes |
#### AppListRow
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| created_by_name | string | | No |
| description | string | | No |
| id | string | | Yes |
| mode | [AppMode](#appmode) | | Yes |
| name | string | | Yes |
| tags | [ [TagItem](#tagitem) ] | | No |
| updated_at | string | | No |
| workspace_id | string | | No |
| workspace_name | string | | No |
#### AppMode
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| AppMode | string | | |
#### AppRunRequest
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| auto_generate_name | boolean | | No |
| conversation_id | string | | No |
| files | [ object ] | | No |
| inputs | object | | Yes |
| query | string | | No |
| workflow_id | string | | No |
| workspace_id | string | | No |
#### DeviceCodeRequest
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| client_id | string | | Yes |
| device_label | string | | Yes |
#### DeviceCodeResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| device_code | string | | Yes |
| expires_in | integer | | Yes |
| interval | integer | | Yes |
| user_code | string | | Yes |
| verification_uri | string | | Yes |
#### DeviceLookupQuery
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| user_code | string | | Yes |
#### DeviceLookupResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| client_id | string | | No |
| expires_in_remaining | integer | | No |
| valid | boolean | | Yes |
#### DeviceMutateRequest
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| user_code | string | | Yes |
#### DeviceMutateResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| status | string | | Yes |
#### DevicePollRequest
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| client_id | string | | Yes |
| device_code | string | | Yes |
#### FileResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| conversation_id | string | | No |
| created_at | integer | | No |
| created_by | string | | No |
| extension | string | | No |
| file_key | string | | No |
| id | string | | Yes |
| mime_type | string | | No |
| name | string | | Yes |
| original_url | string | | No |
| preview_url | string | | No |
| size | integer | | Yes |
| source_url | string | | No |
| tenant_id | string | | No |
| user_id | string | | No |
#### HumanInputFormSubmitPayload
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| action | string | | Yes |
| inputs | object | | Yes |
#### JsonValue
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| JsonValue | | | |
#### MessageMetadata
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| retriever_resources | [ object ] | | No |
| usage | [UsageInfo](#usageinfo) | | No |
#### PermittedExternalAppsListQuery
Strict (extra='forbid').
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| limit | integer | | No |
| mode | [AppMode](#appmode) | | No |
| name | string | | No |
| page | integer | | No |
#### PermittedExternalAppsListResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| data | [ [AppListRow](#applistrow) ] | | Yes |
| has_more | boolean | | Yes |
| limit | integer | | Yes |
| page | integer | | Yes |
| total | integer | | Yes |
#### RevokeResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| status | string | | Yes |
#### ServerVersionResponse
Meta endpoint payload for `GET /openapi/v1/_version` — no auth required.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| edition | string | *Enum:* `"CLOUD"`, `"SELF_HOSTED"` | Yes |
| version | string | | Yes |
#### SessionListResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| data | [ [SessionRow](#sessionrow) ] | | Yes |
| has_more | boolean | | Yes |
| limit | integer | | Yes |
| page | integer | | Yes |
| total | integer | | Yes |
#### SessionRow
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| client_id | string | | Yes |
| created_at | string | | No |
| device_label | string | | Yes |
| expires_at | string | | No |
| id | string | | Yes |
| last_used_at | string | | No |
| prefix | string | | Yes |
#### TagItem
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| name | string | | Yes |
#### UsageInfo
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| completion_tokens | integer | | No |
| prompt_tokens | integer | | No |
| total_tokens | integer | | No |
#### WorkflowRunData
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| created_at | integer | | No |
| elapsed_time | number | | No |
| error | string | | No |
| finished_at | integer | | No |
| id | string | | Yes |
| outputs | object | | No |
| status | string | | Yes |
| total_steps | integer | | No |
| total_tokens | integer | | No |
| workflow_id | string | | Yes |
#### WorkspaceDetailResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| created_at | string | | No |
| current | boolean | | Yes |
| id | string | | Yes |
| name | string | | Yes |
| role | string | | Yes |
| status | string | | Yes |
#### WorkspaceListResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| workspaces | [ [WorkspaceSummaryResponse](#workspacesummaryresponse) ] | | Yes |
#### WorkspacePayload
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| id | string | | Yes |
| name | string | | Yes |
| role | string | | Yes |
#### WorkspaceSummaryResponse
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| current | boolean | | Yes |
| id | string | | Yes |
| name | string | | Yes |
| role | string | | Yes |
| status | string | | Yes |

View File

@ -0,0 +1,54 @@
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
(token_id stays for audits) so rows accumulate across re-logins, and
expired-but-never-presented rows have no hard-expire trigger — both get
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
"""
from __future__ import annotations
import logging
import time
from datetime import UTC, datetime, timedelta
import click
from sqlalchemy import delete, or_, select
import app
from configs import dify_config
from extensions.ext_database import db
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
DELETE_BATCH_SIZE = 500
@app.celery.task(queue="retention")
def clean_oauth_access_tokens_task():
click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
cutoff = datetime.now(UTC) - timedelta(days=retention_days)
start_at = time.perf_counter()
candidates = or_(
OAuthAccessToken.revoked_at < cutoff,
# Zombies: expired but never re-presented, so middleware never flipped them.
(OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
)
total = 0
while True:
ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
if not ids:
break
db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
db.session.commit()
total += len(ids)
end_at = time.perf_counter()
click.echo(
click.style(
f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
fg="green",
)
)

View File

@ -8,7 +8,8 @@ from hashlib import sha256
from typing import Any, TypedDict, cast from typing import Any, TypedDict, cast
from pydantic import BaseModel, TypeAdapter, ValidationError from pydantic import BaseModel, TypeAdapter, ValidationError
from sqlalchemy import delete, func, select, update from sqlalchemy import Row, delete, func, select, update
from sqlalchemy.orm import Session, scoped_session
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
@ -163,6 +164,41 @@ class AccountService:
redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@staticmethod
def get_account_by_email(session: Session | scoped_session, email: str) -> Account | None:
"""Plain ``Account`` getter keyed by email. Case-sensitive — use
:meth:`has_active_account_with_email` for the case-insensitive
existence check that backs the SSO collision rule.
"""
return session.execute(select(Account).where(Account.email == email)).scalar_one_or_none()
@staticmethod
def has_active_account_with_email(session: Session | scoped_session, email: str) -> bool:
if not email:
return False
normalized = email.strip().lower()
if not normalized:
return False
row = session.execute(
select(Account.id).where(
func.lower(Account.email) == normalized,
Account.status == AccountStatus.ACTIVE,
)
).scalar_one_or_none()
return row is not None
@staticmethod
def get_account_by_id(session: Session | scoped_session, account_id: str) -> Account | None:
"""Plain ``Account`` getter — no banned check, no tenant rotation,
no ``last_active_at`` write. Use this from read-only identity
endpoints (``/openapi/v1/account``) where ``load_user``'s
side-effects (current-tenant assignment, commit) are unwanted.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
"""
return session.get(Account, account_id)
@staticmethod @staticmethod
def load_user(user_id: str) -> None | Account: def load_user(user_id: str) -> None | Account:
account = db.session.get(Account, user_id) account = db.session.get(Account, user_id)
@ -1182,6 +1218,127 @@ class TenantService:
).all() ).all()
) )
@staticmethod
def get_account_memberships(
session: Session | scoped_session,
account_id: str,
) -> list[Row[tuple[TenantAccountJoin, Tenant]]]:
"""Return ``(TenantAccountJoin, Tenant)`` rows for every workspace
the account belongs to. Unlike :meth:`get_join_tenants` this keeps
the join row so callers can read ``role``/``current`` alongside the
tenant — used by ``/openapi/v1/account`` to render workspace
membership + pick the default workspace.
``session`` is injected by the caller so this service stays free
of the Flask-scoped ``db.session`` import.
No tenant-status filter: parity with the legacy controller query
(the openapi identity endpoint listed all joined tenants).
"""
return (
session.query(TenantAccountJoin, Tenant)
.join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account_id)
.all()
)
@staticmethod
def get_workspaces_for_account(
session: Session | scoped_session,
account_id: str,
) -> list[Row[tuple[Tenant, TenantAccountJoin]]]:
"""``(Tenant, TenantAccountJoin)`` rows for every workspace the
account belongs to, ordered by ``Tenant.created_at`` ASC — the
canonical ordering for ``/openapi/v1/workspaces``.
Distinct from :meth:`get_account_memberships`: tuple order is
flipped (tenant first) and rows are sorted, so the workspace
listing is stable across requests.
"""
return list(
session.execute(
select(Tenant, TenantAccountJoin)
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == account_id)
.order_by(Tenant.created_at.asc())
).all()
)
@staticmethod
def account_belongs_to_tenant(
session: Session | scoped_session,
account_id: uuid.UUID | str | None,
tenant_id: str,
) -> bool:
"""Existence check for ``TenantAccountJoin(account_id, tenant_id)``.
Backs the CE-deployment membership fallback in
``controllers.openapi.auth.strategies.MembershipStrategy``.
``None``/empty ``account_id`` short-circuits to ``False`` so SSO
bearers (no account) and missing identity collapse cleanly.
"""
if not account_id:
return False
row = session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.tenant_id == tenant_id,
TenantAccountJoin.account_id == account_id,
)
).scalar_one_or_none()
return row is not None
@staticmethod
def get_tenant_by_id(session: Session | scoped_session, tenant_id: str) -> Tenant | None:
"""Plain ``session.get(Tenant, tenant_id)`` — no status filter.
Callers map ``status == ARCHIVE`` to their own error code (the
openapi auth pipeline raises 403 ``workspace unavailable``).
"""
return session.get(Tenant, tenant_id)
@staticmethod
def get_tenants_by_ids(
session: Session | scoped_session,
tenant_ids: list[str],
) -> list[Tenant]:
"""Bulk ``Tenant`` fetch by primary-key list. Order is unspecified
— callers index by ``tenant.id`` (e.g. for cross-tenant denorm
in ``/openapi/v1/permitted-external-apps``).
Empty input short-circuits to ``[]`` to avoid emitting an
``IN ()`` SQL fragment.
"""
if not tenant_ids:
return []
return list(session.execute(select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all())
@staticmethod
def get_tenant_name(session: Session | scoped_session, tenant_id: str) -> str | None:
"""Single-column tenant name read. Used by openapi list endpoints
to denormalize ``workspace_name`` onto each row without dragging
the full ``Tenant`` ORM entity through.
"""
return session.execute(select(Tenant.name).where(Tenant.id == tenant_id)).scalar_one_or_none()
@staticmethod
def find_workspace_for_account(
session: Session | scoped_session,
account_id: str,
workspace_id: str,
) -> Row[tuple[Tenant, TenantAccountJoin]] | None:
"""Single ``(Tenant, TenantAccountJoin)`` row scoped to the
account's membership in ``workspace_id``. ``None`` on non-member
— the caller maps that to 404 (not 403) so workspace IDs don't
leak across tenants via response codes.
"""
return session.execute(
select(Tenant, TenantAccountJoin)
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
.where(
Tenant.id == workspace_id,
TenantAccountJoin.account_id == account_id,
)
).first()
@staticmethod @staticmethod
def get_current_tenant_by_account(account: Account): def get_current_tenant_by_account(account: Account):
"""Get tenant by account and add the role""" """Get tenant by account and add the role"""

View File

@ -1,11 +1,13 @@
import json import json
import logging import logging
from collections.abc import Sequence
from typing import Any, Literal, TypedDict, cast from typing import Any, Literal, TypedDict, cast
import sqlalchemy as sa import sqlalchemy as sa
from flask_sqlalchemy.pagination import Pagination from flask_sqlalchemy.pagination import Pagination
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, scoped_session
from configs import dify_config from configs import dify_config
from constants.model_template import default_app_templates from constants.model_template import default_app_templates
@ -26,6 +28,7 @@ from models.tools import ApiToolProvider
from services.billing_service import BillingService from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from services.openapi.visibility import apply_openapi_gate, is_openapi_visible
from services.tag_service import TagService from services.tag_service import TagService
from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task
@ -39,6 +42,8 @@ class AppListParams(BaseModel):
name: str | None = None name: str | None = None
tag_ids: list[str] | None = None tag_ids: list[str] | None = None
is_created_by_me: bool | None = None is_created_by_me: bool | None = None
status: str | None = None
openapi_visible: bool = False
class CreateAppParams(BaseModel): class CreateAppParams(BaseModel):
@ -54,6 +59,51 @@ class CreateAppParams(BaseModel):
class AppService: class AppService:
@staticmethod
def get_app_by_id(
session: Session | scoped_session,
app_id: str,
) -> App | None:
return session.get(App, app_id)
@staticmethod
def get_visible_app_by_id(
session: Session | scoped_session,
app_id: str,
) -> App | None:
app = session.get(App, app_id)
if not app or app.status != "normal" or not is_openapi_visible(app):
return None
return app
@staticmethod
def find_visible_apps_by_ids(
session: Session | scoped_session,
app_ids: Sequence[str],
) -> list[App]:
if not app_ids:
return []
return list(session.execute(apply_openapi_gate(select(App).where(App.id.in_(list(app_ids))))).scalars().all())
@staticmethod
def find_visible_apps_by_name(
session: Session | scoped_session,
*,
name: str,
tenant_id: str,
) -> list[App]:
return list(
session.execute(
apply_openapi_gate(
select(App).where(
App.name == name,
App.tenant_id == tenant_id,
App.status == "normal",
)
)
).scalars()
)
def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None: def get_paginate_apps(self, user_id: str, tenant_id: str, params: AppListParams) -> Pagination | None:
""" """
Get app list with pagination Get app list with pagination
@ -75,6 +125,14 @@ class AppService:
elif params.mode == "agent-chat": elif params.mode == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT) filters.append(App.mode == AppMode.AGENT_CHAT)
if params.status:
filters.append(App.status == params.status)
# OpenAPI surface visibility gate. Pushed into the query so
# `pagination.total` reflects only apps the openapi caller can
# actually reach — post-filtering by enable_api after the page
# arrives would make `total` page-dependent.
if params.openapi_visible:
filters.append(App.enable_api.is_(True))
if params.is_created_by_me: if params.is_created_by_me:
filters.append(App.created_by == user_id) filters.append(App.created_by == user_id)
if params.name: if params.name:

View File

@ -14,13 +14,13 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.plugin.plugin_service import PluginService
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.provider_entities import FormType from graphon.model_runtime.entities.provider_entities import FormType
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -0,0 +1,44 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from werkzeug.exceptions import ServiceUnavailable
from services.enterprise.enterprise_service import EnterpriseService
from services.errors.enterprise import EnterpriseAPIError
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class PermittedAppsPage:
app_ids: list[str]
total: int
has_more: bool
def list_permitted_apps(
*,
page: int,
limit: int,
mode: str | None = None,
name: str | None = None,
) -> PermittedAppsPage:
try:
body = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
page=page, limit=limit, mode=mode, name=name
)
except EnterpriseAPIError as exc:
logger.warning(
"permitted_apps EE call failed: status=%s message=%s",
getattr(exc, "status_code", None),
str(exc),
)
raise ServiceUnavailable("permitted_apps_unavailable") from exc
return PermittedAppsPage(
app_ids=[row["appId"] for row in body.get("data", [])],
total=int(body.get("total", 0)),
has_more=bool(body.get("hasMore", False)),
)

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import enum
import logging import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
@ -24,10 +25,22 @@ VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppAccessMode(enum.StrEnum):
PUBLIC = "public"
PRIVATE = "private"
PRIVATE_ALL = "private_all"
SSO_VERIFIED = "sso_verified"
PERMISSION_CHECK_MODES: frozenset[WebAppAccessMode] = frozenset(
{WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
)
class WebAppSettings(BaseModel): class WebAppSettings(BaseModel):
access_mode: str = Field( access_mode: str = Field(
description="Access mode for the web app. Can be 'public', 'private', 'private_all', 'sso_verified'", description=f"Access mode for the web app. One of: {', '.join(m.value for m in WebAppAccessMode)}",
default="private", default=WebAppAccessMode.PRIVATE.value,
alias="accessMode", alias="accessMode",
) )
@ -108,6 +121,15 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str): def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
return EnterpriseRequest.send_request(
"POST",
"/device-flow/sso-initiate",
json={"signed_state": signed_state},
raise_for_status=True,
)
@classmethod @classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult: def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
""" """
@ -219,8 +241,9 @@ class EnterpriseService:
def update_app_access_mode(cls, app_id: str, access_mode: str): def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id: if not app_id:
raise ValueError("app_id must be provided.") raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]: allowed = {WebAppAccessMode.PUBLIC, WebAppAccessMode.PRIVATE, WebAppAccessMode.PRIVATE_ALL}
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") if access_mode not in allowed:
raise ValueError(f"access_mode must be one of: {', '.join(m.value for m in allowed)}")
data = {"appId": app_id, "accessMode": access_mode} data = {"appId": app_id, "accessMode": access_mode}
@ -236,6 +259,32 @@ class EnterpriseService:
params = {"appId": app_id} params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def list_externally_accessible_apps(
cls,
*,
page: int,
limit: int,
mode: str | None = None,
name: str | None = None,
) -> dict:
"""Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.
Response shape: ``{"data": [{"appId", "tenantId", "mode", "name", "updatedAt"}],
"total": int, "hasMore": bool}``.
"""
body: dict[str, str | int] = {"page": page, "limit": limit}
if mode is not None:
body["mode"] = mode
if name is not None:
body["name"] = name
return EnterpriseRequest.send_request(
"POST",
"/webapp/externally-accessible-apps",
json=body,
timeout=5.0,
)
@classmethod @classmethod
def get_cached_license_status(cls) -> LicenseStatus | None: def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls. """Get enterprise license status with Redis caching to reduce HTTP calls.

View File

@ -0,0 +1,572 @@
from __future__ import annotations
import hashlib
import json
import logging
import os
import secrets
import time
import uuid
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import Any, NotRequired, TypedDict
from sqlalchemy import and_, func, select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, scoped_session
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
# ============================================================================
# Redis state machine — device_code + user_code ephemeral state
# ============================================================================
_DEVICE_CODE_KEY_PREFIX = "device_code:"
_USER_CODE_KEY_PREFIX = "user_code:"
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"
# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
# not both observe APPROVED — only the winner gets the plaintext token,
# the loser sees nil and the caller maps that to expired_token.
_CONSUME_ON_POLL_LUA = """
local raw = redis.call('GET', KEYS[1])
if not raw then return nil end
local ok, decoded = pcall(cjson.decode, raw)
if not ok then return nil end
if decoded.status == 'pending' then return nil end
if decoded.user_code then
redis.call('DEL', ARGV[1] .. decoded.user_code)
end
redis.call('DEL', KEYS[1])
return raw
"""
DEVICE_FLOW_TTL_SECONDS = 15 * 60 # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60 # plaintext-token lifetime floor
USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789" # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4
USER_CODE_MAX_CLAIM_ATTEMPTS = 5
DEFAULT_POLL_INTERVAL_SECONDS = 5 # RFC 8628 minimum
class DeviceFlowStatus(StrEnum):
PENDING = "pending"
APPROVED = "approved"
DENIED = "denied"
class SlowDownDecision(StrEnum):
OK = "ok"
SLOW_DOWN = "slow_down"
class PollPayload(TypedDict):
"""Body served by the unauthenticated poll endpoint
(`POST /openapi/v1/oauth/device/token`) once approve has run.
A single shape across both branches so the CLI/SPA can parse one
contract:
- ``account`` branch (built in `controllers.openapi.oauth_device.
_build_account_poll_payload`) populates ``account`` + ``workspaces``
+ ``default_workspace_id`` and omits the SSO-only fields.
- ``external_sso`` branch (built in
`controllers.openapi.oauth_device_sso.approve_external`) populates
``subject_email`` + ``subject_issuer`` and zero-fills the
account/workspace fields (``None`` / ``[]``).
Pre-rendering here means the unauthenticated poll handler doesn't
re-query accounts/tenants for authz data.
"""
token: str
expires_at: str
subject_type: SubjectType
account: dict[str, object] | None
workspaces: list[dict[str, object]]
default_workspace_id: str | None
token_id: str
subject_email: NotRequired[str]
subject_issuer: NotRequired[str]
@dataclass
class DeviceFlowState:
"""``minted_token`` is plaintext between approve and the next poll;
DEL'd after the poll reads it.
"""
user_code: str
client_id: str
device_label: str
status: DeviceFlowStatus
subject_email: str | None = None
account_id: str | None = None
subject_issuer: str | None = None
minted_token: str | None = None
token_id: str | None = None
created_at: str = ""
created_ip: str = ""
last_poll_at: str = ""
poll_payload: PollPayload | None = field(default=None)
def to_json(self) -> str:
return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw: str) -> DeviceFlowState:
data = json.loads(raw)
if "status" in data:
data["status"] = DeviceFlowStatus(data["status"])
return cls(**data)
def _random_device_code() -> str:
return "dc_" + secrets.token_urlsafe(24)
def _random_user_code_segment() -> str:
return "".join(secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN))
def _random_user_code() -> str:
return f"{_random_user_code_segment()}-{_random_user_code_segment()}"
class StateNotFoundError(Exception):
pass
class InvalidTransitionError(Exception):
pass
class UserCodeExhaustedError(Exception):
pass
class DeviceFlowRedis:
def __init__(self, redis_client) -> None:
self._redis = redis_client
self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)
def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
device_code = _random_device_code()
user_code = self._claim_user_code(device_code)
state = DeviceFlowState(
user_code=user_code,
client_id=client_id,
device_label=device_label,
status=DeviceFlowStatus.PENDING,
created_at=datetime.now(UTC).isoformat(),
created_ip=created_ip,
)
self._redis.setex(
DEVICE_CODE_KEY_FMT.format(code=device_code),
DEVICE_FLOW_TTL_SECONDS,
state.to_json(),
)
return device_code, user_code, DEVICE_FLOW_TTL_SECONDS
def _claim_user_code(self, device_code: str) -> str:
for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
user_code = _random_user_code()
key = USER_CODE_KEY_FMT.format(code=user_code)
ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
if ok:
return user_code
raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")
def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
if not raw_dc:
return None
device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
state = self._load_state(device_code)
if state is None:
return None
return device_code, state
def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
return self._load_state(device_code)
def _load_state(self, device_code: str) -> DeviceFlowState | None:
raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
if not raw:
return None
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.exception("device_flow: corrupt state for %s", device_code)
return None
def approve(
self,
device_code: str,
subject_email: str,
account_id: str | None,
minted_token: str,
token_id: str,
subject_issuer: str | None = None,
poll_payload: PollPayload | None = None,
) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFoundError(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransitionError(f"cannot approve {state.status}")
state.status = DeviceFlowStatus.APPROVED
state.subject_email = subject_email
state.account_id = account_id
state.subject_issuer = subject_issuer
state.minted_token = minted_token
state.token_id = token_id
state.poll_payload = poll_payload
new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())
def deny(self, device_code: str) -> None:
state = self._load_state(device_code)
if state is None:
raise StateNotFoundError(device_code)
if state.status is not DeviceFlowStatus.PENDING:
raise InvalidTransitionError(f"cannot deny {state.status}")
state.status = DeviceFlowStatus.DENIED
self._redis.setex(
DEVICE_CODE_KEY_FMT.format(code=device_code),
self._remaining_ttl(device_code, floor=1),
state.to_json(),
)
def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
"""Race-safe via Lua EVAL: GET + status-check + DEL execute in a
single Redis transaction so only one of N concurrent pollers
observes the APPROVED state. Losers get None, mapped to
expired_token by the caller.
"""
raw = self._consume_on_poll_script(
keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
args=[_USER_CODE_KEY_PREFIX],
)
if raw is None:
return None
text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
try:
return DeviceFlowState.from_json(text_)
except (ValueError, KeyError):
logger.exception("device_flow: corrupt state on consume %s", device_code)
return None
def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
now = time.time()
key = f"device_code:{device_code}:last_poll"
prev_raw = self._redis.get(key)
self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
if prev_raw is None:
return SlowDownDecision.OK
prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
try:
prev = float(prev_s)
except ValueError:
return SlowDownDecision.OK
if now - prev < interval_seconds:
return SlowDownDecision.SLOW_DOWN
return SlowDownDecision.OK
def _remaining_ttl(self, device_code: str, floor: int) -> int:
"""``max(remaining, floor)`` — guarantees the CLI has at least
``floor`` seconds to poll after a near-expiry approve.
"""
ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
if ttl is None or ttl < 0:
return floor
return max(int(ttl), floor)
# ============================================================================
# Token mint — generate + upsert
# ============================================================================
OAUTH_BODY_BYTES = 32 # ~256 bits entropy
PREFIX_OAUTH_ACCOUNT = "dfoa_"
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"
# Sentinel issuer for account-flow rows. Postgres' default partial unique
# index treats NULLs as distinct, which would let two live `dfoa_` rows
# share (email, client, device) and break rotate-in-place. Storing a
# non-empty literal makes the composite key collide as intended.
ACCOUNT_ISSUER_SENTINEL = "dify:account"
@dataclass(frozen=True, slots=True)
class MintResult:
"""Plaintext token surfaces to the caller once."""
token: str
token_id: uuid.UUID
expires_at: datetime
@dataclass(frozen=True, slots=True)
class UpsertOutcome:
token_id: uuid.UUID
rotated: bool
old_hash: str | None
def generate_token(prefix: str) -> str:
return prefix + secrets.token_urlsafe(OAUTH_BODY_BYTES)
def sha256_hex(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def mint_oauth_token(
# Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
# the wrapper proxies the same execute/commit surface.
session: Session | scoped_session,
redis_client,
*,
subject_email: str,
subject_issuer: str | None,
account_id: str | None,
client_id: str,
device_label: str,
prefix: str,
ttl_days: int,
) -> MintResult:
"""Live row rotates in place via partial unique index
``uq_oauth_active_per_device``; hard-expired rows are excluded by the
index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
deleted so stale AuthContext drops immediately.
"""
if prefix == PREFIX_OAUTH_ACCOUNT:
# Account flow always writes the sentinel — caller may pass None
# (for clarity) or the sentinel itself; nothing else is valid.
if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
subject_issuer = ACCOUNT_ISSUER_SENTINEL
elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
# Defense in depth: enterprise canonicalises + rejects empty,
# but a regression there must not yield a NULL composite key here.
if not subject_issuer or not subject_issuer.strip():
raise ValueError("external-SSO token requires non-empty subject_issuer")
else:
raise ValueError(f"unknown oauth prefix: {prefix!r}")
token = generate_token(prefix)
new_hash = sha256_hex(token)
expires_at = datetime.now(UTC) + timedelta(days=ttl_days)
outcome = _upsert(
session,
subject_email=subject_email,
subject_issuer=subject_issuer,
account_id=account_id,
client_id=client_id,
device_label=device_label,
prefix=prefix,
new_hash=new_hash,
expires_at=expires_at,
)
if outcome.rotated and outcome.old_hash:
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))
return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
def _upsert(
session: Session | scoped_session,
*,
subject_email: str,
subject_issuer: str | None,
account_id: str | None,
client_id: str,
device_label: str,
prefix: str,
new_hash: str,
expires_at: datetime,
) -> UpsertOutcome:
# Snapshot prior live row's hash for Redis invalidation post-rotate.
# subject_issuer is always non-null here (account flow uses sentinel,
# external-SSO is validated upstream), so equality matches the index.
prior = session.execute(
select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
.where(
OAuthAccessToken.subject_email == subject_email,
OAuthAccessToken.subject_issuer == subject_issuer,
OAuthAccessToken.client_id == client_id,
OAuthAccessToken.device_label == device_label,
OAuthAccessToken.revoked_at.is_(None),
)
.limit(1)
).first()
old_hash = prior.token_hash if prior else None
insert_stmt = pg_insert(OAuthAccessToken).values(
subject_email=subject_email,
subject_issuer=subject_issuer,
account_id=account_id,
client_id=client_id,
device_label=device_label,
prefix=prefix,
token_hash=new_hash,
expires_at=expires_at,
)
upsert_stmt = insert_stmt.on_conflict_do_update(
index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
index_where=OAuthAccessToken.revoked_at.is_(None),
set_={
"token_hash": insert_stmt.excluded.token_hash,
"prefix": insert_stmt.excluded.prefix,
"account_id": insert_stmt.excluded.account_id,
"expires_at": insert_stmt.excluded.expires_at,
"created_at": func.now(),
"last_used_at": None,
},
).returning(OAuthAccessToken.id)
row = session.execute(upsert_stmt).first()
session.commit()
if row is None:
raise RuntimeError("oauth_token upsert returned no row")
token_id = uuid.UUID(str(row.id))
return UpsertOutcome(
token_id=token_id,
rotated=prior is not None,
old_hash=old_hash,
)
# ============================================================================
# TTL policy — days new OAuth tokens live
# ============================================================================
DEFAULT_OAUTH_TTL_DAYS = 14
MIN_TTL_DAYS = 1
MAX_TTL_DAYS = 365
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
def oauth_ttl_days(tenant_id: str | None = None) -> int:
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
is deferred; when it lands it wins over the env (Redis-cached 60s).
"""
_ = tenant_id
raw = os.environ.get(_TTL_ENV_VAR)
if raw is None:
return DEFAULT_OAUTH_TTL_DAYS
try:
value = int(raw)
except ValueError:
logger.warning(
"%s=%r is not an int; falling back to %d",
_TTL_ENV_VAR,
raw,
DEFAULT_OAUTH_TTL_DAYS,
)
return DEFAULT_OAUTH_TTL_DAYS
if value < MIN_TTL_DAYS:
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
return MIN_TTL_DAYS
if value > MAX_TTL_DAYS:
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
return MAX_TTL_DAYS
return value
def subject_match_clauses(ctx: AuthContext) -> tuple[Any, ...]:
if ctx.subject_type == SubjectType.ACCOUNT:
return (OAuthAccessToken.account_id == str(ctx.account_id),)
return (
OAuthAccessToken.subject_email == ctx.subject_email,
OAuthAccessToken.subject_issuer == ctx.subject_issuer,
OAuthAccessToken.account_id.is_(None),
)
def list_active_sessions(
session: Session | scoped_session,
ctx: AuthContext,
now: datetime,
) -> list[OAuthAccessToken]:
return list(
session.execute(
select(OAuthAccessToken)
.where(
and_(
*subject_match_clauses(ctx),
OAuthAccessToken.revoked_at.is_(None),
OAuthAccessToken.token_hash.is_not(None),
OAuthAccessToken.expires_at > now,
)
)
.order_by(OAuthAccessToken.created_at.desc())
)
.scalars()
.all()
)
def token_belongs_to_subject(
session: Session | scoped_session,
token_id: str,
ctx: AuthContext,
) -> bool:
row = session.execute(
select(OAuthAccessToken.id).where(
and_(
OAuthAccessToken.id == token_id,
*subject_match_clauses(ctx),
)
)
).first()
return row is not None
def revoke_oauth_token(
session: Session | scoped_session,
redis_client: Any,
token_id: str,
) -> None:
row = (
session.query(OAuthAccessToken.token_hash)
.filter(
OAuthAccessToken.id == token_id,
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
pre_revoke_hash = row[0] if row else None
stmt = (
update(OAuthAccessToken)
.where(
OAuthAccessToken.id == token_id,
OAuthAccessToken.revoked_at.is_(None),
)
.values(revoked_at=datetime.now(UTC), token_hash=None)
)
session.execute(stmt)
session.commit()
if pre_revoke_hash:
redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))

View File

View File

@ -0,0 +1,52 @@
"""License gate for the /openapi/v1/permitted-external-apps* surface.
EE-only. CE deploys (``ENTERPRISE_ENABLED=false``) skip the gate entirely —
the EE blueprint chain is what gives CE deploys no callers on this surface
in practice, but the explicit short-circuit avoids any test/fixture that
flips the surface on without flipping the license.
Reuses ``FeatureService.get_system_features()`` so the license status
travels the same path as the console reads.
Companion to ``controllers.console.wraps.enterprise_license_required`` —
that one is for console (cookie-authed, force-logout 401). This one is
for bearer surface (token-authed, 403 ``license_required``).
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from functools import wraps
from werkzeug.exceptions import Forbidden
from configs import dify_config
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
_VALID_LICENSE_STATUSES: frozenset[LicenseStatus] = frozenset({LicenseStatus.ACTIVE, LicenseStatus.EXPIRING})
def license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""Decorator form. Raises ``Forbidden('license_required')`` when the EE
deployment has no valid license. No-op on CE (``ENTERPRISE_ENABLED=false``).
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
if dify_config.ENTERPRISE_ENABLED and not _is_license_valid():
raise Forbidden(description="license_required")
return view(*args, **kwargs)
return decorated
def _is_license_valid() -> bool:
try:
features = FeatureService.get_system_features()
except Exception:
logger.exception("license_gate: FeatureService.get_system_features failed")
return False
return features.license.status in _VALID_LICENSE_STATUSES

View File

@ -0,0 +1,47 @@
"""Hard mint policy.
``validate_mint_policy`` cross-checks a (subject_type, prefix, scopes)
triple a caller intends to mint against ``MINTABLE_PROFILES`` —
the single source of truth in ``libs.oauth_bearer``.
The defense-in-depth value: if a future caller assembles ``prefix`` or
``scopes`` from a non-canonical source (env, request body, plug-in
contribution), the mismatch fails closed at approve time before any
row hits the DB. When the caller reads straight from
``MINTABLE_PROFILES``, the check is a structural pin — it confirms the
table entry is well-formed and the caller picked the right key.
"""
from __future__ import annotations
from libs.oauth_bearer import MINTABLE_PROFILES, Scope, SubjectType
class MintPolicyViolation(Exception): # noqa: N818 — spec-defined name, used in BadRequest message
"""Raised on a (subject_type, prefix, scopes) mismatch. Callers translate
to 400 ``mint_policy_violation``."""
def validate_mint_policy(
*,
subject_type: SubjectType,
prefix: str,
scopes: frozenset[Scope],
) -> None:
"""Raise ``MintPolicyViolation`` when the triple does not match the
canonical ``MINTABLE_PROFILES`` entry for ``subject_type``.
"""
profile = MINTABLE_PROFILES.get(subject_type)
if profile is None:
raise MintPolicyViolation(f"mint_policy_violation: unknown subject_type={subject_type!r}")
drift = []
if profile.prefix != prefix:
drift.append(f"prefix got={prefix!r} expected={profile.prefix!r}")
if frozenset(scopes) != profile.scopes:
got = sorted(s.value for s in scopes)
want = sorted(s.value for s in profile.scopes)
drift.append(f"scopes got={got} expected={want}")
if drift:
raise MintPolicyViolation(f"mint_policy_violation: subject_type={subject_type.value}" + "; ".join(drift))

View File

@ -0,0 +1,32 @@
"""Single-source visibility filter for the /openapi/v1/* surface.
Keep every openapi-surface app query routed through ``_apply_openapi_gate``;
retiring or replacing the gate then becomes a one-line change here.
The Service API (/v1/* app-key surface) does NOT use this helper — that
surface has its own per-request guard (``service_api_disabled``) wired
into the legacy ``validate_app_token`` decorator.
"""
from __future__ import annotations
from typing import Any
from models.model import App
def apply_openapi_gate(query: Any) -> Any:
"""Filter a SQLAlchemy Select/Query to apps visible on /openapi/v1/*.
Works with both legacy ``Query.filter`` and 2.0-style ``Select.filter``
(alias of ``.where``).
"""
return query.filter(App.enable_api.is_(True))
def is_openapi_visible(app: App) -> bool:
"""Per-row counterpart for code paths that fetch an App by primary key
(``session.get`` / ``session.scalar``) and need the same visibility check
the query gate would have applied.
"""
return bool(app.enable_api)

View File

@ -22,7 +22,6 @@ from core.helper import marketplace
from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant from models.account import Tenant
@ -30,6 +29,7 @@ from models.model import App, AppMode, AppModelConfig
from models.provider_ids import ModelProviderID, ToolProviderID from models.provider_ids import ModelProviderID, ToolProviderID
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider
from models.workflow import Workflow from models.workflow import Workflow
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -389,19 +389,17 @@ class PluginMigration:
for plugin_id in batch_plugin_ids for plugin_id in batch_plugin_ids
if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"] if plugin_id not in installed_plugins_ids and plugin_id in plugins["plugins"]
] ]
if batch_plugin_identifiers: manager.install_from_identifiers(
manager.install_from_identifiers( tenant_id,
tenant_id, batch_plugin_identifiers,
batch_plugin_identifiers, PluginInstallationSource.Marketplace,
PluginInstallationSource.Marketplace, metas=[
metas=[ {
{ "plugin_unique_identifier": identifier,
"plugin_unique_identifier": identifier, }
} for identifier in batch_plugin_identifiers
for identifier in batch_plugin_identifiers ],
], )
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
with open(extracted_plugins) as f: with open(extracted_plugins) as f:
""" """
@ -597,7 +595,6 @@ class PluginMigration:
for identifier in batch_plugin_identifiers for identifier in batch_plugin_identifiers
], ],
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
except Exception: except Exception:
# add to failed # add to failed
failed.extend(batch_plugin_identifiers) failed.extend(batch_plugin_identifiers)
@ -612,7 +609,6 @@ class PluginMigration:
while not done: while not done:
status = manager.fetch_plugin_installation_task(tenant_id, task_id) status = manager.fetch_plugin_installation_task(tenant_id, task_id)
if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]: if status.status in [PluginInstallTaskStatus.Failed, PluginInstallTaskStatus.Success]:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
for plugin in status.plugins: for plugin in status.plugins:
if plugin.status == PluginInstallTaskStatus.Success: if plugin.status == PluginInstallTaskStatus.Success:
success.append(reverse_map[plugin.plugin_unique_identifier]) success.append(reverse_map[plugin.plugin_unique_identifier])

View File

@ -1,17 +1,8 @@
"""Core plugin service and tenant-scoped plugin metadata cache ownership.
This module owns plugin daemon management calls that are shared by API services
and core runtimes. Plugin model provider discovery is cached here, alongside
plugin install, uninstall, and upgrade invalidation, so all cache mutations for
plugin-owned provider metadata stay tenant-scoped and in one place.
"""
import logging import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from mimetypes import guess_type from mimetypes import guess_type
from pydantic import BaseModel, TypeAdapter, ValidationError from pydantic import BaseModel
from redis import RedisError
from sqlalchemy import delete, select, update from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from yarl import URL from yarl import URL
@ -31,20 +22,16 @@ from core.plugin.entities.plugin import (
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse, PluginDecodeResponse,
PluginInstallTask, PluginInstallTask,
PluginInstallTaskStatus,
PluginListResponse, PluginListResponse,
PluginModelProviderEntity,
PluginVerification, PluginVerification,
) )
from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from models.provider_ids import GenericProviderID, ModelProviderID from models.provider_ids import GenericProviderID
from services.enterprise.plugin_manager_service import ( from services.enterprise.plugin_manager_service import (
PluginManagerService, PluginManagerService,
PreUninstallPluginRequest, PreUninstallPluginRequest,
@ -53,7 +40,6 @@ from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope from services.feature_service import FeatureService, PluginInstallationScope
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity])
class PluginService: class PluginService:
@ -67,102 +53,6 @@ class PluginService:
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:" REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes REDIS_TTL = 60 * 5 # 5 minutes
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
@classmethod
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
@staticmethod
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
@classmethod
def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = cls._get_provider_short_name_alias(provider)
return declaration
@classmethod
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
try:
cached_providers = redis_client.get(cache_key)
except (RedisError, RuntimeError):
logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
return None
if not cached_providers:
return None
try:
return tuple(_provider_entities_adapter.validate_json(cached_providers))
except (TypeError, ValueError, ValidationError):
logger.warning(
"Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True
)
cls.invalidate_plugin_model_providers_cache(tenant_id)
return None
@classmethod
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
try:
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
except (RedisError, RuntimeError):
logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True)
@classmethod
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
"""Delete the tenant-scoped plugin model provider list cache."""
try:
redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id))
except (RedisError, RuntimeError):
logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
@classmethod
def fetch_plugin_model_providers(
cls, *, tenant_id: str, client: PluginModelClient | None = None
) -> Sequence[ProviderEntity]:
"""
Fetch plugin model providers through the tenant-scoped plugin cache.
Plugin daemon provider discovery and plugin lifecycle cache invalidation
are intentionally owned by this service so tenant isolation and cache
expiry are handled in one place.
"""
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
if cached_providers is not None:
return cached_providers
model_client = client or PluginModelClient()
providers = tuple(
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
)
cls._store_cached_plugin_model_providers(tenant_id, providers)
return providers
@staticmethod @staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]: def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
@ -358,18 +248,12 @@ class PluginService:
Fetch plugin installation tasks Fetch plugin installation tasks
""" """
manager = PluginInstaller() manager = PluginInstaller()
tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks):
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return tasks
@staticmethod @staticmethod
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask: def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
manager = PluginInstaller() manager = PluginInstaller()
task = manager.fetch_plugin_installation_task(tenant_id, task_id) return manager.fetch_plugin_installation_task(tenant_id, task_id)
if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return task
@staticmethod @staticmethod
def delete_install_task(tenant_id: str, task_id: str) -> bool: def delete_install_task(tenant_id: str, task_id: str) -> bool:
@ -431,7 +315,7 @@ class PluginService:
# check if the plugin is available to install # check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification) PluginService._check_plugin_installation_scope(response.verification)
result = manager.upgrade_plugin( return manager.upgrade_plugin(
tenant_id, tenant_id,
original_plugin_unique_identifier, original_plugin_unique_identifier,
new_plugin_unique_identifier, new_plugin_unique_identifier,
@ -440,8 +324,6 @@ class PluginService:
"plugin_unique_identifier": new_plugin_unique_identifier, "plugin_unique_identifier": new_plugin_unique_identifier,
}, },
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod @staticmethod
def upgrade_plugin_with_github( def upgrade_plugin_with_github(
@ -457,7 +339,7 @@ class PluginService:
""" """
PluginService._check_marketplace_only_permission() PluginService._check_marketplace_only_permission()
manager = PluginInstaller() manager = PluginInstaller()
result = manager.upgrade_plugin( return manager.upgrade_plugin(
tenant_id, tenant_id,
original_plugin_unique_identifier, original_plugin_unique_identifier,
new_plugin_unique_identifier, new_plugin_unique_identifier,
@ -468,8 +350,6 @@ class PluginService:
"package": package, "package": package,
}, },
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod @staticmethod
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
@ -535,14 +415,12 @@ class PluginService:
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(resp.verification) PluginService._check_plugin_installation_scope(resp.verification)
result = manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,
plugin_unique_identifiers, plugin_unique_identifiers,
PluginInstallationSource.Package, PluginInstallationSource.Package,
[{}], [{}],
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod @staticmethod
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str): def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
@ -556,7 +434,7 @@ class PluginService:
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(plugin_decode_response.verification) PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
result = manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,
[plugin_unique_identifier], [plugin_unique_identifier],
PluginInstallationSource.Github, PluginInstallationSource.Github,
@ -568,8 +446,6 @@ class PluginService:
} }
], ],
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod @staticmethod
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration: def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
@ -637,14 +513,12 @@ class PluginService:
actual_plugin_unique_identifiers.append(response.unique_identifier) actual_plugin_unique_identifiers.append(response.unique_identifier)
metas.append({"plugin_unique_identifier": response.unique_identifier}) metas.append({"plugin_unique_identifier": response.unique_identifier})
result = manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,
actual_plugin_unique_identifiers, actual_plugin_unique_identifiers,
PluginInstallationSource.Marketplace, PluginInstallationSource.Marketplace,
metas, metas,
) )
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod @staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
@ -655,10 +529,7 @@ class PluginService:
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None) plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if not plugin: if not plugin:
result = manager.uninstall(tenant_id, plugin_installation_id) return manager.uninstall(tenant_id, plugin_installation_id)
if result:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
if dify_config.ENTERPRISE_ENABLED: if dify_config.ENTERPRISE_ENABLED:
PluginManagerService.try_pre_uninstall_plugin( PluginManagerService.try_pre_uninstall_plugin(
@ -688,39 +559,37 @@ class PluginService:
if not credential_ids: if not credential_ids:
logger.info("No credentials found for plugin: %s", plugin_id) logger.info("No credentials found for plugin: %s", plugin_id)
else: return manager.uninstall(tenant_id, plugin_installation_id)
provider_ids = session.scalars(
select(Provider.id).where(
Provider.tenant_id == tenant_id,
Provider.provider_name.like(f"{plugin_id}/%"),
Provider.credential_id.in_(credential_ids),
)
).all()
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None)) provider_ids = session.scalars(
select(Provider.id).where(
for provider_id in provider_ids: Provider.tenant_id == tenant_id,
ProviderCredentialsCache( Provider.provider_name.like(f"{plugin_id}/%"),
tenant_id=tenant_id, Provider.credential_id.in_(credential_ids),
identity_id=provider_id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
).delete()
session.execute(
delete(ProviderCredential).where(
ProviderCredential.id.in_(credential_ids),
)
) )
).all()
logger.info( session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
"Completed deleting credentials and cleaning provider associations for plugin: %s",
plugin_id, for provider_id in provider_ids:
ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
).delete()
session.execute(
delete(ProviderCredential).where(
ProviderCredential.id.in_(credential_ids),
) )
)
result = manager.uninstall(tenant_id, plugin_installation_id) logger.info(
if result: "Completed deleting credentials and cleaning provider associations for plugin: %s",
PluginService.invalidate_plugin_model_providers_cache(tenant_id) plugin_id,
return result )
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod @staticmethod
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]: def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:

View File

@ -12,7 +12,6 @@ from sqlalchemy import select
from configs import dify_config from configs import dify_config
from constants import DOCUMENT_EXTENSIONS from constants import DOCUMENT_EXTENSIONS
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
@ -23,6 +22,7 @@ from models.model import UploadFile
from models.workflow import Workflow, WorkflowType from models.workflow import Workflow, WorkflowType
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting
from services.plugin.plugin_migration import PluginMigration from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -13,7 +13,6 @@ from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.plugin_service import PluginService
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ( from core.tools.entities.api_entities import (
@ -32,6 +31,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID from models.provider_ids import ToolProviderID
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -9,7 +9,6 @@ from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
from core.plugin.plugin_service import PluginService
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -28,6 +27,7 @@ from core.tools.utils.encryption import create_provider_encrypter, create_tool_p
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -14,7 +14,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter from core.helper.provider_encryption import ProviderConfigEncrypter, create_provider_encrypter
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.plugin.plugin_service import PluginService
from core.tools.utils.system_encryption import decrypt_system_params from core.tools.utils.system_encryption import decrypt_system_params
from core.trigger.entities.api_entities import ( from core.trigger.entities.api_entities import (
TriggerProviderApiEntity, TriggerProviderApiEntity,
@ -38,6 +37,7 @@ from models.trigger import (
TriggerSubscription, TriggerSubscription,
WorkflowPluginTrigger, WorkflowPluginTrigger,
) )
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -15,7 +15,7 @@ from models import Account, AccountStatus
from models.model import App, EndUser, Site from models.model import App, EndUser, Site
from services.account_service import AccountService from services.account_service import AccountService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import PERMISSION_CHECK_MODES, EnterpriseService, WebAppAccessMode
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task
@ -137,12 +137,8 @@ class WebAppAuthService:
""" """
Check if the app requires permission check based on its access mode. Check if the app requires permission check based on its access mode.
""" """
modes_requiring_permission_check = [
"private",
"private_all",
]
if access_mode: if access_mode:
return access_mode in modes_requiring_permission_check return access_mode in PERMISSION_CHECK_MODES
if not app_code and not app_id: if not app_code and not app_id:
raise ValueError("Either app_code or app_id must be provided.") raise ValueError("Either app_code or app_id must be provided.")
@ -153,7 +149,7 @@ class WebAppAuthService:
raise ValueError("App ID could not be determined from the provided app_code.") raise ValueError("App ID could not be determined from the provided app_code.")
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: if webapp_settings and webapp_settings.access_mode in PERMISSION_CHECK_MODES:
return True return True
return False return False
@ -166,11 +162,11 @@ class WebAppAuthService:
raise ValueError("Either app_code or access_mode must be provided.") raise ValueError("Either app_code or access_mode must be provided.")
if access_mode: if access_mode:
if access_mode == "public": if access_mode == WebAppAccessMode.PUBLIC:
return WebAppAuthType.PUBLIC return WebAppAuthType.PUBLIC
elif access_mode in ["private", "private_all"]: elif access_mode in PERMISSION_CHECK_MODES:
return WebAppAuthType.INTERNAL return WebAppAuthType.INTERNAL
elif access_mode == "sso_verified": elif access_mode == WebAppAccessMode.SSO_VERIFIED:
return WebAppAuthType.EXTERNAL return WebAppAuthType.EXTERNAL
if app_code: if app_code:

View File

@ -6,11 +6,11 @@ from typing import Any, TypedDict
from sqlalchemy import and_, func, or_, select from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.plugin.plugin_service import PluginService
from graphon.enums import WorkflowExecutionStatus from graphon.enums import WorkflowExecutionStatus
from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
from services.workflow.entities import TriggerMetadata from services.workflow.entities import TriggerMetadata

View File

@ -162,12 +162,18 @@ class _AppRunner:
user = self._resolve_user() user = self._resolve_user()
with self._setup_flask_context(user): with self._setup_flask_context(user):
response = self._run_app( try:
app=app, response = self._run_app(
workflow=workflow, app=app,
user=user, workflow=workflow,
pause_state_config=pause_config, user=user,
) pause_state_config=pause_config,
)
except Exception as exc:
if exec_params.streaming:
_publish_error_event(exc, exec_params.workflow_run_id, exec_params.app_mode)
raise
if not exec_params.streaming: if not exec_params.streaming:
return response return response
@ -238,6 +244,12 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun
return session.get(EndUser, workflow_run.created_by) return session.get(EndUser, workflow_run.created_by)
def _publish_error_event(exc: Exception, workflow_run_id: str, app_mode: AppMode) -> None:
topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id)
payload = json.dumps({"event": "error", "message": str(exc), "status": 500})
topic.publish(payload.encode())
def _publish_streaming_response( def _publish_streaming_response(
response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None],
workflow_run_id: str, workflow_run_id: str,

View File

@ -9,9 +9,9 @@ from celery import shared_task
from core.plugin.entities.marketplace import MarketplacePluginSnapshot from core.plugin.entities.marketplace import MarketplacePluginSnapshot
from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.account import TenantPluginAutoUpgradeStrategy from models.account import TenantPluginAutoUpgradeStrategy
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -0,0 +1,125 @@
"""Shared fixtures for /openapi/v1/* integration tests."""
from __future__ import annotations
import hashlib
import uuid
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
import pytest
from flask import Flask
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
from models.account import AccountStatus
def _sha256(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
@pytest.fixture(autouse=True)
def disable_enterprise(monkeypatch):
"""Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
EE branch override this with their own monkeypatch in-test."""
from configs import dify_config
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
@pytest.fixture
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
with flask_app.app_context():
tenant = Tenant(name="t1", status="normal")
account = Account(email="u@example.com", name="u")
db.session.add_all([tenant, account])
db.session.commit()
account.status = AccountStatus.ACTIVE
join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
db.session.add(join)
db.session.commit()
yield account, tenant, join
db.session.delete(join)
db.session.delete(account)
db.session.delete(tenant)
db.session.commit()
@pytest.fixture
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
_, tenant, _ = workspace_account
with flask_app.app_context():
app = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
db.session.add(app)
db.session.commit()
yield app
db.session.delete(app)
db.session.commit()
@pytest.fixture
def mint_token(flask_app: Flask):
"""Factory fixture; tracks minted rows and deletes them on teardown so
the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
minted: list[OAuthAccessToken] = []
def _mint(
token: str,
*,
account_id: str | None,
prefix: str,
subject_email: str,
subject_issuer: str | None,
) -> OAuthAccessToken:
with flask_app.app_context():
row = OAuthAccessToken(
token_hash=_sha256(token),
prefix=prefix,
account_id=account_id,
subject_email=subject_email,
subject_issuer=subject_issuer,
client_id="difyctl",
device_label="test-device",
expires_at=datetime.now(UTC) + timedelta(hours=1),
)
db.session.add(row)
db.session.commit()
minted.append(row)
return row
yield _mint
with flask_app.app_context():
for row in minted:
db.session.delete(db.session.merge(row))
db.session.commit()
@pytest.fixture
def account_token(workspace_account, mint_token) -> str:
account, _, _ = workspace_account
token = "dfoa_" + uuid.uuid4().hex
mint_token(
token,
account_id=account.id,
prefix="dfoa_",
subject_email=account.email,
subject_issuer="dify:account",
)
return token
@pytest.fixture(autouse=True)
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
def _flush():
with flask_app.app_context():
for k in redis_client.keys("auth:*"):
redis_client.delete(k)
for k in redis_client.keys("rl:*"):
redis_client.delete(k)
_flush()
yield
_flush()

View File

@ -0,0 +1,238 @@
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
from __future__ import annotations
import uuid
from collections.abc import Generator
import pytest
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models import App
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
captured = {}
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
captured["mode"] = app_model.mode
captured["args"] = args
captured["invoke_from"] = invoke_from
return {
"event": "message",
"task_id": "t",
"id": "m",
"message_id": "m",
"conversation_id": "c",
"mode": "chat",
"answer": "ok",
"created_at": 0,
}
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.get_json()["mode"] == "chat"
assert captured["mode"] == "chat"
assert captured["invoke_from"] == InvokeFrom.OPENAPI
assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
@pytest.fixture
def app_with_mode(flask_app: Flask, workspace_account):
"""Factory that creates an App row in the workspace_account tenant with
a specified mode. Tracks rows for teardown.
"""
_, tenant, _ = workspace_account
created: list[App] = []
def _make(mode: str) -> App:
with flask_app.app_context():
app = App(
tenant_id=tenant.id,
name=f"a-{mode}",
mode=mode,
status="normal",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.commit()
db.session.refresh(app)
db.session.expunge(app)
created.append(app)
return app
yield _make
with flask_app.app_context():
for app in created:
db.session.delete(db.session.merge(app))
db.session.commit()
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"inputs": {}, "response_mode": "blocking"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422
assert b"query_required_for_chat" in res.data
def test_run_completion_dispatches_to_completion_handler(flask_app, account_token, app_with_mode, monkeypatch):
app = app_with_mode("completion")
captured: dict = {}
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
captured["mode"] = app_model.mode
captured["args"] = args
return {
"event": "message",
"task_id": "t",
"id": "m",
"message_id": "m",
"mode": "completion",
"answer": "ok",
"created_at": 0,
}
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app.id}/run",
json={"inputs": {}, "response_mode": "blocking"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.get_json()["mode"] == "completion"
assert captured["mode"] == "completion"
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
app = app_with_mode("workflow")
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app.id}/run",
json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422
assert b"query_not_supported_for_workflow" in res.data
def test_run_workflow_no_query_dispatches_to_workflow_handler(flask_app, account_token, app_with_mode, monkeypatch):
app = app_with_mode("workflow")
def _fake_generate(*, app_model, user, args, invoke_from, streaming):
return {
"workflow_run_id": "wfr",
"task_id": "t",
"data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
}
monkeypatch.setattr("controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate))
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app.id}/run",
json={"inputs": {}, "response_mode": "blocking"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.get_json()
assert body["mode"] == "workflow"
assert body["workflow_run_id"] == "wfr"
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
app = app_with_mode("channel")
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app.id}/run",
json={"inputs": {}, "response_mode": "blocking"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422
assert b"mode_not_runnable" in res.data
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"inputs": {}, "query": "hi"},
)
assert res.status_code == 401
def test_run_with_insufficient_scope_returns_403(flask_app, account_token, app_in_workspace, monkeypatch):
"""Stub the authenticator to return an AuthContext with empty scopes."""
from libs import oauth_bearer
real_authenticate = oauth_bearer.BearerAuthenticator.authenticate
def _stub_authenticate(self, token: str):
ctx = real_authenticate(self, token)
from dataclasses import replace
return replace(ctx, scopes=frozenset())
monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"inputs": {}, "query": "hi"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 403
def test_run_with_unknown_app_returns_404(flask_app, account_token):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{uuid.uuid4()}/run",
json={"inputs": {}, "query": "hi"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 404
def test_run_streaming_returns_event_stream(flask_app, account_token, app_in_workspace, monkeypatch):
def _stream() -> Generator[str, None, None]:
yield 'event: message\ndata: {"x": 1}\n\n'
monkeypatch.setattr(
"controllers.openapi.app_run.AppGenerateService.generate",
staticmethod(lambda **kw: _stream()),
)
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.headers["Content-Type"].startswith("text/event-stream")
assert b"event: message" in res.data
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
client = flask_app.test_client()
res = client.post(
f"/openapi/v1/apps/{app_in_workspace.id}/run",
json={"query": "hi"},
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422

View File

@ -0,0 +1,210 @@
"""Integration tests for /openapi/v1/apps* read surface."""
from __future__ import annotations
from flask.testing import FlaskClient
from models import App
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
resp = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}",
headers={"Authorization": f"Bearer {account_token}"},
)
assert resp.status_code == 404
def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
resp = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/parameters",
headers={"Authorization": f"Bearer {account_token}"},
)
assert resp.status_code == 404
def test_apps_info_route_404(test_client, app_in_workspace, account_token):
resp = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/info",
headers={"Authorization": f"Bearer {account_token}"},
)
assert resp.status_code == 404
def test_apps_describe_returns_merged_shape(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"]["id"] == app_in_workspace.id
assert body["info"]["mode"] == "chat"
assert isinstance(body["parameters"], dict)
def test_apps_describe_full_includes_input_schema(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"] is not None
assert body["parameters"] is not None
assert body["input_schema"] is not None
assert body["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
def test_apps_describe_fields_info_only(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"] is not None
assert body["parameters"] is None
assert body["input_schema"] is None
def test_apps_describe_fields_parameters_only(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"] is None
assert body["parameters"] is not None
assert body["input_schema"] is None
def test_apps_describe_fields_input_schema_only(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"] is None
assert body["parameters"] is None
assert body["input_schema"] is not None
def test_apps_describe_fields_combined(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["info"] is not None
assert body["parameters"] is None
assert body["input_schema"] is not None
def test_apps_describe_fields_unknown_returns_422(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422
def test_apps_describe_fields_extra_param_returns_422(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
):
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 422
def test_apps_list_returns_pagination_envelope(
test_client: FlaskClient,
workspace_account,
app_in_workspace: App,
account_token: str,
):
_, tenant, _ = workspace_account
res = test_client.get(
f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
body = res.json
assert body["page"] == 1
assert body["limit"] == 20
assert body["total"] >= 1
assert any(d["id"] == app_in_workspace.id for d in body["data"])
def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
res = test_client.get("/openapi/v1/apps", headers={"Authorization": f"Bearer {account_token}"})
assert res.status_code == 400
def test_apps_list_tag_no_match_returns_empty_data_not_400(
test_client: FlaskClient,
workspace_account,
app_in_workspace: App,
account_token: str,
):
_, tenant, _ = workspace_account
res = test_client.get(
f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.json["data"] == []
def test_account_sessions_returns_envelope(
test_client: FlaskClient,
account_token: str,
):
res = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
assert res.status_code == 200
body = res.json
# canonical envelope shape
assert isinstance(body["data"], list)
assert "page" in body
assert "limit" in body
assert "total" in body
assert "has_more" in body
# the bearer's own minted session must appear
assert any(s["prefix"] == "dfoa_" for s in body["data"])
# legacy "sessions" key must NOT appear
assert "sessions" not in body

View File

@ -0,0 +1,127 @@
"""Integration tests for the /openapi/v1 bearer auth surface.
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
acceptance/rejection on app-scoped routes.
"""
from __future__ import annotations
from collections.abc import Generator
import pytest
from flask import Flask
from flask.testing import FlaskClient
from extensions.ext_database import db
from models import App, Tenant
def test_info_accepts_account_bearer_with_apps_read_scope(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
) -> None:
res = test_client.get(
f"/openapi/v1/apps/{app_in_workspace.id}/info",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.json["id"] == app_in_workspace.id
@pytest.fixture
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
"""A fresh app under a *different* tenant — caller has no membership row."""
with flask_app.app_context():
other_tenant = Tenant(name="other", status="normal")
db.session.add(other_tenant)
db.session.commit()
app = App(
tenant_id=other_tenant.id,
name="b",
mode="chat",
status="normal",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.commit()
yield app
db.session.delete(app)
db.session.delete(other_tenant)
db.session.commit()
def test_layer0_denies_account_bearer_without_membership(
test_client: FlaskClient,
account_token: str,
other_workspace_app: App,
) -> None:
"""Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
res = test_client.get(
f"/openapi/v1/apps/{other_workspace_app.id}/info",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 403
assert res.json.get("message") == "workspace_membership_revoked"
def test_layer0_skipped_when_enterprise_enabled(
test_client: FlaskClient,
account_token: str,
other_workspace_app: App,
monkeypatch,
) -> None:
"""On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.
/info uses validate_bearer + require_workspace_member inline (no
AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
gets 200 — gateway is expected to enforce isolation upstream.
"""
from configs import dify_config
# Override the conftest autouse default for this test only.
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)
res = test_client.get(
f"/openapi/v1/apps/{other_workspace_app.id}/info",
headers={"Authorization": f"Bearer {account_token}"},
)
assert res.status_code == 200
assert res.json.get("message") != "workspace_membership_revoked"
def test_rate_limit_returns_429_after_60_requests(
test_client: FlaskClient,
account_token: str,
) -> None:
"""61st sequential GET to /account on the same bearer → 429 with Retry-After."""
headers = {"Authorization": f"Bearer {account_token}"}
for i in range(60):
r = test_client.get("/openapi/v1/account", headers=headers)
assert r.status_code == 200, f"unexpected fail at i={i}"
r = test_client.get("/openapi/v1/account", headers=headers)
assert r.status_code == 429
assert r.headers.get("Retry-After"), "Retry-After header missing"
assert int(r.headers["Retry-After"]) >= 1
body = r.json or {}
assert body.get("error") == "rate_limited"
assert isinstance(body.get("retry_after_ms"), int)
assert body["retry_after_ms"] >= 1000
def test_rate_limit_bucket_shared_across_surfaces(
test_client: FlaskClient,
app_in_workspace: App,
account_token: str,
) -> None:
"""30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
headers = {"Authorization": f"Bearer {account_token}"}
for _ in range(30):
assert test_client.get("/openapi/v1/account", headers=headers).status_code == 200
for _ in range(30):
assert test_client.get(f"/openapi/v1/apps/{app_in_workspace.id}/info", headers=headers).status_code == 200
r = test_client.get("/openapi/v1/account", headers=headers)
assert r.status_code == 429

View File

@ -1,4 +1,4 @@
"""Tests for core.plugin.plugin_service.PluginService. """Tests for services.plugin.plugin_service.PluginService.
Covers: version caching with Redis, install permission/scope gates, Covers: version caching with Redis, install permission/scope gates,
icon URL construction, asset retrieval with MIME guessing, plugin icon URL construction, asset retrieval with MIME guessing, plugin
@ -17,11 +17,11 @@ from sqlalchemy.orm import Session
from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.entities.plugin_daemon import PluginVerification from core.plugin.entities.plugin_daemon import PluginVerification
from core.plugin.plugin_service import PluginService
from models import ProviderType from models import ProviderType
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from services.errors.plugin import PluginInstallationForbiddenError from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import PluginInstallationScope from services.feature_service import PluginInstallationScope
from services.plugin.plugin_service import PluginService
def _make_features( def _make_features(
@ -35,8 +35,8 @@ def _make_features(
class TestFetchLatestPluginVersion: class TestFetchLatestPluginVersion:
@patch("core.plugin.plugin_service.marketplace") @patch("services.plugin.plugin_service.marketplace")
@patch("core.plugin.plugin_service.redis_client") @patch("services.plugin.plugin_service.redis_client")
def test_returns_cached_version(self, mock_redis, mock_marketplace): def test_returns_cached_version(self, mock_redis, mock_marketplace):
cached_json = PluginService.LatestPluginCache( cached_json = PluginService.LatestPluginCache(
plugin_id="p1", plugin_id="p1",
@ -53,8 +53,8 @@ class TestFetchLatestPluginVersion:
assert result["p1"].version == "1.0.0" assert result["p1"].version == "1.0.0"
mock_marketplace.batch_fetch_plugin_manifests.assert_not_called() mock_marketplace.batch_fetch_plugin_manifests.assert_not_called()
@patch("core.plugin.plugin_service.marketplace") @patch("services.plugin.plugin_service.marketplace")
@patch("core.plugin.plugin_service.redis_client") @patch("services.plugin.plugin_service.redis_client")
def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace): def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace):
mock_redis.get.return_value = None mock_redis.get.return_value = None
manifest = MagicMock() manifest = MagicMock()
@ -71,8 +71,8 @@ class TestFetchLatestPluginVersion:
assert result["p1"].version == "2.0.0" assert result["p1"].version == "2.0.0"
mock_redis.setex.assert_called_once() mock_redis.setex.assert_called_once()
@patch("core.plugin.plugin_service.marketplace") @patch("services.plugin.plugin_service.marketplace")
@patch("core.plugin.plugin_service.redis_client") @patch("services.plugin.plugin_service.redis_client")
def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace): def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_marketplace.batch_fetch_plugin_manifests.return_value = [] mock_marketplace.batch_fetch_plugin_manifests.return_value = []
@ -81,8 +81,8 @@ class TestFetchLatestPluginVersion:
assert result["unknown"] is None assert result["unknown"] is None
@patch("core.plugin.plugin_service.marketplace") @patch("services.plugin.plugin_service.marketplace")
@patch("core.plugin.plugin_service.redis_client") @patch("services.plugin.plugin_service.redis_client")
def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace): def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace):
mock_redis.get.return_value = None mock_redis.get.return_value = None
mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error") mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error")
@ -93,14 +93,14 @@ class TestFetchLatestPluginVersion:
class TestCheckMarketplaceOnlyPermission: class TestCheckMarketplaceOnlyPermission:
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_raises_when_restricted(self, mock_fs): def test_raises_when_restricted(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True)
with pytest.raises(PluginInstallationForbiddenError): with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_marketplace_only_permission() PluginService._check_marketplace_only_permission()
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_passes_when_not_restricted(self, mock_fs): def test_passes_when_not_restricted(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False)
@ -108,7 +108,7 @@ class TestCheckMarketplaceOnlyPermission:
class TestCheckPluginInstallationScope: class TestCheckPluginInstallationScope:
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_official_only_allows_langgenius(self, mock_fs): def test_official_only_allows_langgenius(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
verification = MagicMock() verification = MagicMock()
@ -116,14 +116,14 @@ class TestCheckPluginInstallationScope:
PluginService._check_plugin_installation_scope(verification) # should not raise PluginService._check_plugin_installation_scope(verification) # should not raise
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_official_only_rejects_third_party(self, mock_fs): def test_official_only_rejects_third_party(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY)
with pytest.raises(PluginInstallationForbiddenError): with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_plugin_installation_scope(None) PluginService._check_plugin_installation_scope(None)
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_official_and_partners_allows_partner(self, mock_fs): def test_official_and_partners_allows_partner(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features( mock_fs.get_system_features.return_value = _make_features(
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
@ -133,7 +133,7 @@ class TestCheckPluginInstallationScope:
PluginService._check_plugin_installation_scope(verification) # should not raise PluginService._check_plugin_installation_scope(verification) # should not raise
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_official_and_partners_rejects_none(self, mock_fs): def test_official_and_partners_rejects_none(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features( mock_fs.get_system_features.return_value = _make_features(
scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS
@ -142,7 +142,7 @@ class TestCheckPluginInstallationScope:
with pytest.raises(PluginInstallationForbiddenError): with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_plugin_installation_scope(None) PluginService._check_plugin_installation_scope(None)
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_none_scope_always_raises(self, mock_fs): def test_none_scope_always_raises(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE)
verification = MagicMock() verification = MagicMock()
@ -151,7 +151,7 @@ class TestCheckPluginInstallationScope:
with pytest.raises(PluginInstallationForbiddenError): with pytest.raises(PluginInstallationForbiddenError):
PluginService._check_plugin_installation_scope(verification) PluginService._check_plugin_installation_scope(verification)
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
def test_all_scope_passes_any(self, mock_fs): def test_all_scope_passes_any(self, mock_fs):
mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL)
@ -159,7 +159,7 @@ class TestCheckPluginInstallationScope:
class TestGetPluginIconUrl: class TestGetPluginIconUrl:
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_constructs_url_with_params(self, mock_config): def test_constructs_url_with_params(self, mock_config):
mock_config.CONSOLE_API_URL = "https://console.example.com" mock_config.CONSOLE_API_URL = "https://console.example.com"
@ -171,7 +171,7 @@ class TestGetPluginIconUrl:
class TestGetAsset: class TestGetAsset:
@patch("core.plugin.plugin_service.PluginAssetManager") @patch("services.plugin.plugin_service.PluginAssetManager")
def test_returns_bytes_and_guessed_mime(self, mock_asset_cls): def test_returns_bytes_and_guessed_mime(self, mock_asset_cls):
mock_asset_cls.return_value.fetch_asset.return_value = b"<svg/>" mock_asset_cls.return_value.fetch_asset.return_value = b"<svg/>"
@ -180,7 +180,7 @@ class TestGetAsset:
assert data == b"<svg/>" assert data == b"<svg/>"
assert "svg" in mime assert "svg" in mime
@patch("core.plugin.plugin_service.PluginAssetManager") @patch("services.plugin.plugin_service.PluginAssetManager")
def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls): def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls):
mock_asset_cls.return_value.fetch_asset.return_value = b"\x00" mock_asset_cls.return_value.fetch_asset.return_value = b"\x00"
@ -190,13 +190,13 @@ class TestGetAsset:
class TestIsPluginVerified: class TestIsPluginVerified:
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_returns_true_when_verified(self, mock_installer_cls): def test_returns_true_when_verified(self, mock_installer_cls):
mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True
assert PluginService.is_plugin_verified("t1", "uid-1") is True assert PluginService.is_plugin_verified("t1", "uid-1") is True
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_returns_false_on_exception(self, mock_installer_cls): def test_returns_false_on_exception(self, mock_installer_cls):
mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found")
@ -204,24 +204,24 @@ class TestIsPluginVerified:
class TestUpgradePluginWithMarketplace: class TestUpgradePluginWithMarketplace:
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_raises_when_marketplace_disabled(self, mock_config): def test_raises_when_marketplace_disabled(self, mock_config):
mock_config.MARKETPLACE_ENABLED = False mock_config.MARKETPLACE_ENABLED = False
with pytest.raises(ValueError, match="marketplace is not enabled"): with pytest.raises(ValueError, match="marketplace is not enabled"):
PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid")
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_raises_when_same_identifier(self, mock_config): def test_raises_when_same_identifier(self, mock_config):
mock_config.MARKETPLACE_ENABLED = True mock_config.MARKETPLACE_ENABLED = True
with pytest.raises(ValueError, match="same plugin"): with pytest.raises(ValueError, match="same plugin"):
PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid") PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid")
@patch("core.plugin.plugin_service.marketplace") @patch("services.plugin.plugin_service.marketplace")
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace):
mock_config.MARKETPLACE_ENABLED = True mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
@ -234,10 +234,10 @@ class TestUpgradePluginWithMarketplace:
mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid") mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid")
installer.upgrade_plugin.assert_called_once() installer.upgrade_plugin.assert_called_once()
@patch("core.plugin.plugin_service.download_plugin_pkg") @patch("services.plugin.plugin_service.download_plugin_pkg")
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download):
mock_config.MARKETPLACE_ENABLED = True mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
@ -256,8 +256,8 @@ class TestUpgradePluginWithMarketplace:
class TestUpgradePluginWithGithub: class TestUpgradePluginWithGithub:
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs):
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
installer = mock_installer_cls.return_value installer = mock_installer_cls.return_value
@ -271,8 +271,8 @@ class TestUpgradePluginWithGithub:
class TestUploadPkg: class TestUploadPkg:
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs):
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
upload_resp = MagicMock() upload_resp = MagicMock()
@ -285,17 +285,17 @@ class TestUploadPkg:
class TestInstallFromMarketplacePkg: class TestInstallFromMarketplacePkg:
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_raises_when_marketplace_disabled(self, mock_config): def test_raises_when_marketplace_disabled(self, mock_config):
mock_config.MARKETPLACE_ENABLED = False mock_config.MARKETPLACE_ENABLED = False
with pytest.raises(ValueError, match="marketplace is not enabled"): with pytest.raises(ValueError, match="marketplace is not enabled"):
PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) PluginService.install_from_marketplace_pkg("t1", ["uid-1"])
@patch("core.plugin.plugin_service.download_plugin_pkg") @patch("services.plugin.plugin_service.download_plugin_pkg")
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download):
mock_config.MARKETPLACE_ENABLED = True mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
@ -315,9 +315,9 @@ class TestInstallFromMarketplacePkg:
call_args = installer.install_from_identifiers.call_args[0] call_args = installer.install_from_identifiers.call_args[0]
assert call_args[1] == ["resolved-uid"] assert call_args[1] == ["resolved-uid"]
@patch("core.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.FeatureService")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
@patch("core.plugin.plugin_service.dify_config") @patch("services.plugin.plugin_service.dify_config")
def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs):
mock_config.MARKETPLACE_ENABLED = True mock_config.MARKETPLACE_ENABLED = True
mock_fs.get_system_features.return_value = _make_features() mock_fs.get_system_features.return_value = _make_features()
@ -336,7 +336,7 @@ class TestInstallFromMarketplacePkg:
class TestUninstall: class TestUninstall:
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls): def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls):
installer = mock_installer_cls.return_value installer = mock_installer_cls.return_value
installer.list_plugins.return_value = [] installer.list_plugins.return_value = []
@ -347,7 +347,7 @@ class TestUninstall:
assert result is True assert result is True
installer.uninstall.assert_called_once_with("t1", "install-1") installer.uninstall.assert_called_once_with("t1", "install-1")
@patch("core.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.PluginInstaller")
def test_cleans_credentials_when_plugin_found( def test_cleans_credentials_when_plugin_found(
self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session self, mock_installer_cls, flask_app_with_containers: Flask, db_session_with_containers: Session
): ):
@ -389,7 +389,7 @@ class TestUninstall:
installer.list_plugins.return_value = [plugin] installer.list_plugins.return_value = [plugin]
installer.uninstall.return_value = True installer.uninstall.return_value = True
with patch("core.plugin.plugin_service.dify_config") as mock_config: with patch("services.plugin.plugin_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False mock_config.ENTERPRISE_ENABLED = False
result = PluginService.uninstall(tenant_id, "install-1") result = PluginService.uninstall(tenant_id, "install-1")

View File

@ -6,7 +6,6 @@ import pytest
from faker import Faker from faker import Faker
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.plugin.plugin_service import PluginService
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -21,6 +20,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
@ -31,7 +31,7 @@ class TestToolTransformService:
def mock_external_service_dependencies(self): def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies.""" """Mock setup for external service dependencies."""
with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config: with patch("services.tools.tools_transform_service.dify_config") as mock_dify_config:
with patch("core.plugin.plugin_service.dify_config", new=mock_dify_config): with patch("services.plugin.plugin_service.dify_config", new=mock_dify_config):
# Setup default mock returns # Setup default mock returns
mock_dify_config.CONSOLE_API_URL = "https://console.example.com" mock_dify_config.CONSOLE_API_URL = "https://console.example.com"

Some files were not shown because too many files have changed in this diff Show More