Compare commits

..

36 Commits

Author SHA1 Message Date
114daf3729 [autofix.ci] apply automated fixes 2026-05-26 10:09:11 +00:00
69fb870946 fix(agent v2): plugin tools builder review fixes
Addresses review comments on the Dify Plugin Tools integration:

* **provider_type default ``"plugin"``** (was ``"builtin"``): Dify Plugin
  Tools live behind ToolProviderType.PLUGIN; the built-in path resolves a
  different provider table and would surface as a misleading
  ``agent_tool_declaration_not_found`` for every plugin tool by default.

* **Thread invoke_from through to ToolManager**: the runtime request
  builder was hardcoding ``InvokeFrom.VALIDATION`` regardless of
  caller. ``WorkflowAgentPluginToolsBuilder.build`` now takes the real
  invoke_from from ``DifyRunContext`` so credential quotas, rate limits,
  and audit tags match the actual call site (DEBUGGER for draft test
  run, SERVICE_API / WEB_APP for published).

* **Narrower exception mapping**: instead of one ``except Exception``
  that always returns ``agent_tool_declaration_not_found``, distinguish:
  - ToolProviderNotFoundError → declaration_not_found
  - ToolProviderCredentialValidationError → credential_invalid
  - bare ValueError (e.g. "runtime not found") → config_invalid

* **Reject non-scalar credential values** instead of ``str(value)``:
  forwarding a nested OAuth dict as its Python ``repr`` to plugin daemon
  would silently misroute auth. Surface ``agent_tool_credential_shape_invalid``
  so the operator fixes the credential schema.

* **AgentSoulDifyToolConfig hardening**:
  - ``extra="ignore"`` (was ``extra="allow"``) so stale fields drop on
    load instead of riding into ``model_dump``.
  - Drop unused ``parameter_overrides`` field.
  - Reject explicit ``name`` overrides until rename UX lands — Stage 3.1
    silently ignored them which is harder to debug than rejecting.

Test additions:
  - invoke_from forwarding for DEBUGGER / SERVICE_API / WEB_APP
  - disabled tools short-circuit (no ToolManager call)
  - plugin_id + provider fallback when provider_id missing
  - unauthorized tool with empty credentials
  - exception → error_code mapping (declaration / credential / config)
  - non-scalar credential rejection
  - legacy provider_name / tool_parameters / credential_id normalization

21 plugin-tools-builder + runtime-request-builder tests pass; ruff clean.
2026-05-26 18:03:09 +08:00
21b6c2bec1 [autofix.ci] apply automated fixes 2026-05-26 09:50:18 +00:00
a41fa5607b add agent backend plugin layer 2026-05-26 17:46:40 +08:00
fb07b43107 feat(api): Node Output Inspector service + 3 REST endpoints (Stage 4 §8) (#36644)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 07:34:33 +00:00
0dad426101 chore: add dependabot to lts branch (#36424) 2026-05-26 07:08:08 +00:00
2a1df4de62 chore(deps): bump boto3 from 1.43.10 to 1.43.14 in /api in the storage group (#36595)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-26 06:47:59 +00:00
2b97f6c8c2 chore: inject tenant id in extension handlers (#36656) 2026-05-26 05:45:03 +00:00
75d6511284 chore: inject account context in file handlers (#36655) 2026-05-26 05:43:57 +00:00
fd059720e5 chore: inject tenant id in feature handlers (#36654) 2026-05-26 05:36:02 +00:00
2a5f7bb1aa chore: inject current user in explore message handlers (#36652) 2026-05-26 05:31:51 +00:00
0f06aa2fdd feat(dify-agent): sync agent progress (#36633)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-26 03:14:10 +00:00
yyh
884e2b864b feat(dify-ui): add textarea primitive (#36547) 2026-05-26 02:33:32 +00:00
a728e0ac69 feat: adding dify cli (#36348)
Co-authored-by: GareArc <garethcxy@dify.ai>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: L1nSn0w <l1nsn0w@qq.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: gigglewang <gigglewang@dify.ai>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Xiyuan Chen <52963600+GareArc@users.noreply.github.com>
2026-05-26 01:12:36 +00:00
7d464d014c fix: remove unused datasource_parameters from Notion pre-import query (#36627) 2026-05-26 01:05:30 +00:00
0ce0127e7e fix(security): reject path traversal sequences before plugin daemon forward (GHSA-gvc6-fh3x-89xh) (#35796)
Co-authored-by: Ido Shani <ido@zafran.io>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-25 16:17:39 +00:00
25da7ae0d9 chore: dep inject for sql session (#36545)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-25 14:24:58 +00:00
4d6f8eba2a fix: normalize summary_index_setting None to fix preview error (#36626) 2026-05-25 13:42:45 +00:00
87268f0662 chore: inject current user in console handlers (#36628) 2026-05-25 13:14:08 +00:00
135e01930b chore: example of current user id dep injection (#36588)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 11:31:40 +00:00
yyh
fe86fa31ec fix: normalize app icon picker dialog state (#36621) 2026-05-25 10:39:52 +00:00
b1f0a11d84 feat: output declaration and inspector (#36618)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 10:08:58 +00:00
fbfb4b3a00 chore: use dify_config.BILLING_ENABLED (#36619) 2026-05-25 09:41:01 +00:00
3a467d1d63 fix: member invite limits with dedup, locking, and accurate new-member counting (#36512) 2026-05-25 08:58:42 +00:00
yyh
23539c5bcc feat(dify-ui): add status and progress primitives (#36615) 2026-05-25 08:31:52 +00:00
9ddd98a265 fix(api): preserve dataset nested null shapes (#36611)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: wangxiaolei <fatelei@gmail.com>
2026-05-25 08:06:33 +00:00
yyh
ecfee2f072 fix: center align slider thumb (#36614) 2026-05-25 07:55:30 +00:00
345ba80942 fix: type mismatches (route says uuid: but handler says str) (#36612) 2026-05-25 07:33:32 +00:00
e617435d03 fix: replace .distinct() with .group_by(Conversation.id) for PostgreSQL JSON compatibility (#36610)
Co-authored-by: cocoon <kuishou68@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 07:15:24 +00:00
5f7eb7bde9 feat: add workflow_version to workflow_agent_node_bindings (#36603)
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 06:26:19 +00:00
yyh
eb41c9b769 chore: upgrade dependencies (#36606) 2026-05-25 05:42:35 +00:00
yyh
8876efb419 refactor(dify-ui): rename toggle group to segmented control (#36605) 2026-05-25 04:57:39 +00:00
adb14d23de feat(dify-agent): add history layer and structural output layer (#36600) 2026-05-25 04:28:17 +00:00
6f1623e02a chore(i18n): sync translations with en-US (#36599)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-05-25 03:06:45 +00:00
67d99723ea fix: External retrieval model response rejects empty score threshold bug (#36577)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 03:01:06 +00:00
639e12a306 fix: request /api/datasets raise exception (#36591)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-25 02:27:54 +00:00
785 changed files with 50005 additions and 5959 deletions

15
.dockerignore Normal file
View File

@ -0,0 +1,15 @@
**/node_modules
**/.pnpm-store
**/dist
**/.next
**/.turbo
**/.cache
**/__pycache__
**/*.pyc
**/.mypy_cache
**/.ruff_cache
.git
.github
*.md
!web/README.md
!api/README.md

4
.gitattributes vendored
View File

@ -5,3 +5,7 @@
# them.
*.sh text eol=lf
# Codegen output must stay byte-identical across platforms so
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
*.generated.ts text eol=lf

4
.github/CODEOWNERS vendored
View File

@ -18,6 +18,10 @@
# Docs
/docs/ @crazywoola
# CLI
/cli/ @langgenius/maintainers
/.github/workflows/cli-tests.yml @langgenius/maintainers
# Backend (default owner, more specific rules below will override)
/api/ @QuantumGhost

111
.github/dependabot.yml vendored
View File

@ -110,3 +110,114 @@ updates:
github-actions-dependencies:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "github-actions"
directory: "/"
target-branch: "lts/1.13.x"
open-pull-requests-limit: 5
schedule:
interval: "weekly"
groups:
github-actions-dependencies:
patterns:
- "*"

88
.github/workflows/cli-release.yml vendored Normal file
View File

@ -0,0 +1,88 @@
name: CLI Release
on:
workflow_dispatch:
push:
tags:
- 'difyctl-v*'
concurrency:
group: cli-release-${{ github.ref }}
cancel-in-progress: true
jobs:
release:
name: build standalone binaries (all targets)
runs-on: depot-ubuntu-24.04
if: github.repository == 'langgenius/dify'
permissions:
contents: write
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
fetch-depth: 0
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Setup Bun
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
with:
bun-version: latest
- name: Read cli/package.json
id: manifest
run: |
version=$(node -p "require('./package.json').version")
channel=$(node -p "require('./package.json').difyctl.channel")
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
{
echo "version=$version"
echo "channel=$channel"
echo "minDify=$minDify"
echo "maxDify=$maxDify"
} >> "$GITHUB_OUTPUT"
- name: Validate manifest
run: scripts/release-validate-manifest.sh
- name: Install cross-arch native prebuilds
# Re-installs node_modules with every @napi-rs/keyring platform variant
# so `bun build --compile` can embed the right .node into each target.
working-directory: ./
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
- name: Compile standalone binaries (all targets)
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
run: |
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
pnpm build:bin
- name: Generate sha256 checksum file
env:
CLI_VERSION: ${{ steps.manifest.outputs.version }}
run: scripts/release-write-checksums.sh
- name: Publish GitHub Release
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
with:
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
name: difyctl ${{ steps.manifest.outputs.version }}
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
generate_release_notes: true
fail_on_unmatched_files: true
files: |
cli/dist/bin/difyctl-v*

60
.github/workflows/cli-smoke.yml vendored Normal file
View File

@ -0,0 +1,60 @@
name: CLI Smoke (live dify)
on:
workflow_dispatch:
inputs:
dify_version:
description: "Dify image tag to test against (e.g. 1.7.0)"
type: string
required: true
cli_ref:
description: "Git ref to build the cli from (default: current branch)"
type: string
required: false
permissions:
contents: read
jobs:
smoke:
runs-on: ubuntu-latest
timeout-minutes: 30
defaults:
run:
shell: bash
steps:
- name: Checkout cli ref
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
ref: ${{ inputs.cli_ref || github.ref }}
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Bring up dify
env:
DIFY_VERSION: ${{ inputs.dify_version }}
run: |
cd docker
cp .env.example .env
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
docker compose up -d api worker web db redis
for i in $(seq 1 60); do
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
echo "dify api ready after ${i}s"
break
fi
sleep 1
done
- name: Run smoke against live dify
working-directory: ./cli
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
- name: Dump dify logs on failure
if: failure()
run: |
cd docker
docker compose logs api worker web --tail=200

46
.github/workflows/cli-tests.yml vendored Normal file
View File

@ -0,0 +1,46 @@
name: CLI Tests
on:
workflow_call:
secrets:
CODECOV_TOKEN:
required: false
permissions:
contents: read
concurrency:
group: cli-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: CLI Tests
runs-on: depot-ubuntu-24.04
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
run:
shell: bash
working-directory: ./cli
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: CI pipeline (typecheck, lint, coverage, build)
run: pnpm ci
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: cli/coverage
flags: cli
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}

View File

@ -42,6 +42,7 @@ jobs:
runs-on: depot-ubuntu-24.04
outputs:
api-changed: ${{ steps.changes.outputs.api }}
cli-changed: ${{ steps.changes.outputs.cli }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
@ -62,6 +63,18 @@ jobs:
- 'docker/generate_docker_compose'
- 'docker/ssrf_proxy/**'
- 'docker/volumes/sandbox/conf/**'
cli:
- 'cli/**'
- 'packages/tsconfig/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- 'eslint.config.mjs'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/cli-tests.yml'
- '.github/workflows/cli-docker-build.yml'
- '.github/actions/setup-web/**'
web:
- 'web/**'
- 'packages/**'
@ -184,6 +197,66 @@ jobs:
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
cli-tests-run:
name: Run CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
uses: ./.github/workflows/cli-tests.yml
secrets: inherit
cli-tests-skip:
name: Skip CLI Tests
needs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
runs-on: depot-ubuntu-24.04
steps:
- name: Report skipped CLI tests
run: echo "No CLI-related changes detected; skipping CLI tests."
cli-tests:
name: CLI Tests
if: ${{ always() }}
needs:
- pre_job
- check-changes
- cli-tests-run
- cli-tests-skip
runs-on: depot-ubuntu-24.04
steps:
- name: Finalize CLI Tests status
env:
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
RUN_RESULT: ${{ needs.cli-tests-run.result }}
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
run: |
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
exit 0
fi
if [[ "$TESTS_CHANGED" == 'true' ]]; then
if [[ "$RUN_RESULT" == 'success' ]]; then
echo "CLI tests ran successfully."
exit 0
fi
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
exit 1
fi
if [[ "$SKIP_RESULT" == 'success' ]]; then
echo "CLI tests were skipped because no CLI-related files changed."
exit 0
fi
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
exit 1
web-tests-run:
name: Run Web Tests
needs:

9
.gitignore vendored
View File

@ -115,6 +115,12 @@ venv/
ENV/
env.bak/
venv.bak/
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
!/cli/src/env/
!/cli/src/commands/env/
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
!/cli/scripts/lib/
.conda/
# Spyder project settings
@ -247,8 +253,9 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
*.local.toml
# Code Agent Folder
.qoder/*
.context/*
.context/
.eslintcache

View File

@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
ext_logstore,
ext_mail,
ext_migrate,
ext_oauth_bearer,
ext_orjson,
ext_otel,
ext_proxy_fix,
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
ext_oauth_bearer,
]
for ext in extensions:
short_name = ext.__name__.split(".")[-1]

View File

@ -30,7 +30,8 @@ from clients.agent_backend.factory import create_agent_backend_run_client
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
from clients.agent_backend.request_builder import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_PLUGIN_CONTEXT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_PLUGIN_TOOLS_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendModelConfig,
@ -42,7 +43,8 @@ from clients.agent_backend.request_builder import (
__all__ = [
"AGENT_SOUL_PROMPT_LAYER_ID",
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
"DIFY_PLUGIN_TOOLS_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendError",

View File

@ -4,7 +4,9 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
protocol has a single owner. API-only context such as Agent Soul vs workflow job
prompt is preserved in layer names and metadata until the dedicated product
schemas land in later phases.
schemas land in later phases. Dify-owned execution identifiers are emitted as an
explicit ``dify.execution_context`` layer so the run request stays fully
composition-driven.
"""
from __future__ import annotations
@ -15,18 +17,21 @@ from agenton.compositor import CompositorSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLayerConfig,
DifyPluginLLMLayerConfig,
DifyPluginToolsLayerConfig,
)
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
RunComposition,
RunLayerSpec,
@ -37,17 +42,16 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
tenant_id: str
plugin_id: str
model_provider: str
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
@ -55,10 +59,14 @@ class AgentBackendModelConfig(BaseModel):
class AgentBackendOutputConfig(BaseModel):
"""API-side structured output declaration for the conventional output layer."""
"""API-side structured output declaration for the conventional output layer.
The structured-output tool name is fixed to ``final_output`` inside
``dify_agent.layers.output`` so callers only control the JSON Schema plus
optional description/strictness metadata.
"""
json_schema: dict[str, JsonValue]
name: str = "final_result"
description: str | None = None
strict: bool | None = None
@ -69,13 +77,14 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
model: AgentBackendModelConfig
execution_context: ExecutionContext
execution_context: DifyExecutionContextLayerConfig
workflow_node_job_prompt: str
user_prompt: str
agent_soul_prompt: str | None = None
purpose: RunPurpose = "workflow_node"
idempotency_key: str | None = None
output: AgentBackendOutputConfig | None = None
tools: DifyPluginToolsLayerConfig | None = None
session_snapshot: CompositorSessionSnapshot | None = None
suspend_on_exit: bool = False
metadata: dict[str, JsonValue] = Field(default_factory=dict)
@ -121,21 +130,18 @@ class AgentBackendRunRequestBuilder:
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
type=DIFY_PLUGIN_LAYER_TYPE_ID,
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=DifyPluginLayerConfig(
tenant_id=run_input.model.tenant_id,
plugin_id=run_input.model.plugin_id,
user_id=run_input.model.user_id,
),
config=run_input.execution_context,
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
@ -145,6 +151,17 @@ class AgentBackendRunRequestBuilder:
]
)
if run_input.tools is not None and run_input.tools.tools:
layers.append(
RunLayerSpec(
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=run_input.tools,
)
)
if run_input.output is not None:
layers.append(
RunLayerSpec(
@ -153,7 +170,6 @@ class AgentBackendRunRequestBuilder:
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
name=run_input.output.name,
description=run_input.output.description,
strict=run_input.output.strict,
),
@ -162,7 +178,6 @@ class AgentBackendRunRequestBuilder:
return CreateRunRequest(
composition=RunComposition(layers=layers),
execution_context=run_input.execution_context,
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,

View File

@ -1,3 +1,5 @@
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
@ -23,7 +25,7 @@ class DeploymentConfig(BaseSettings):
default=False,
)
EDITION: str = Field(
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
)

View File

@ -525,6 +525,44 @@ class HttpConfig(BaseSettings):
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
OPENAPI_ENABLED: bool = Field(
description=(
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
"programmatic clients. Set to true to activate; disabled by default."
),
validation_alias=AliasChoices("OPENAPI_ENABLED"),
default=False,
)
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
description=(
"Comma-separated allowlist for /openapi/v1/* CORS. "
"Default empty = same-origin only. Browser-cookie routes within "
"the group reject cross-origin OPTIONS regardless of this list."
),
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
default="",
)
@computed_field
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
description=(
"Comma-separated client_id values accepted at "
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
"without code changes. Unknown client_id returns 400 unsupported_client."
),
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
default="difyctl",
)
@computed_field # type: ignore[misc]
@property
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
)
@ -900,6 +938,17 @@ class AuthConfig(BaseSettings):
default=86400,
)
ENABLE_OAUTH_BEARER: bool = Field(
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
default=True,
)
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
default=60,
)
class ModerationConfig(BaseSettings):
"""
@ -1186,6 +1235,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
default=True,
)
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
description="Days to retain revoked OAuth access-token rows before deletion.",
default=30,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@ -68,6 +68,7 @@ from .app import (
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_node_output_inspector,
workflow_run,
workflow_statistic,
workflow_trigger,
@ -218,6 +219,7 @@ __all__ = [
"workflow_app_log",
"workflow_comment",
"workflow_draft_variable",
"workflow_node_output_inspector",
"workflow_run",
"workflow_statistic",
"workflow_trigger",

View File

@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
from controllers.common.helpers import FileInfo
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.app.wraps import get_app_model, with_session
from controllers.console.workspace.models import LoadBalancingPayload
from controllers.console.wraps import (
account_initialization_required,
@ -26,7 +26,6 @@ from controllers.console.wraps import (
is_admin_or_owner_required,
setup_required,
)
from core.db.session_factory import session_factory
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -852,11 +851,11 @@ class AppTraceApi(Resource):
@setup_required
@login_required
@account_initialization_required
@with_session
@get_app_model
def get(self, app_model):
def get(self, session: Session, app_model: App):
"""Get app trace"""
with session_factory.create_session() as session:
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
return app_trace_config

View File

@ -134,7 +134,7 @@ class CompletionConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
elif args.annotation_status == "not_annotated":
query = (
@ -272,7 +272,7 @@ class ChatConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.distinct()
.group_by(Conversation.id)
)
case "not_annotated":
query = (

View File

@ -417,7 +417,7 @@ class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model, message_id: str):
def get(self, app_model, message_id: UUID):
message_id_str = str(message_id)
message = db.session.scalar(

View File

@ -2,6 +2,7 @@ import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -345,14 +346,15 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: str):
def get(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
return variable
@ -363,7 +365,7 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: str):
def patch(self, app_model: App, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -390,10 +392,11 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
new_name = args_model.name
@ -434,14 +437,15 @@ class VariableApi(Resource):
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
def delete(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
@ -457,7 +461,7 @@ class VariableResetApi(Resource):
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
def put(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -468,10 +472,11 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)

View File

@ -0,0 +1,415 @@
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
with a producer-organized view of each node's declared outputs and their
per-run status. This module exposes two parallel sets of three read-only
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
schedule / plugin triggers). Both sets share the same service code, the same
response shapes, and the same error codes; the URL is the *only* difference,
so the frontend can pick the right prefix based on which run-detail page the
user is on.
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
``published_run_inspector_not_implemented`` 404 code is therefore no longer
produced.
URLs follow the design doc and reuse the existing
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
:mod:`controllers.console.app.workflow_draft_variable`. The
``published`` prefix mirrors it shape-for-shape.
"""
from __future__ import annotations
import json
import logging
from collections.abc import Iterator
from uuid import UUID
from flask import Response
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from libs.exception import BaseHTTPException
from libs.login import login_required
from models import App, AppMode
from services.workflow import inspector_events
from services.workflow.node_output_inspector_service import (
NodeOutputInspectorError,
NodeOutputInspectorService,
)
logger = logging.getLogger(__name__)
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
# intervening proxies (nginx, ingress) don't reap the idle connection.
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
_HEARTBEAT_EVERY_TICKS = 15
# Hard ceiling on a single stream — if we never see a terminal workflow
# event (engine crashed, redis dropped the message), force-close after this
# many ticks (= seconds).
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
def _service() -> NodeOutputInspectorService:
"""One-line factory so tests can monkeypatch a stub if needed."""
return NodeOutputInspectorService()
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
"""Resource-body shared by draft + published snapshot endpoints.
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
behaviour lives here, where unit tests can hit it without spinning up
Flask request context.
"""
try:
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return snapshot.model_dump(mode="json")
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
"""Resource-body shared by draft + published node-detail endpoints."""
try:
view = _service().node_detail(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return view.model_dump(mode="json")
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
"""Resource-body shared by draft + published output-preview endpoints."""
try:
preview = _service().output_preview(
app_model=app_model,
workflow_run_id=str(run_id),
node_id=node_id,
output_name=output_name,
)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
return preview.model_dump(mode="json")
class _InspectorNotFound(BaseHTTPException):
"""404 that preserves the inspector's specific error code.
Without this the response body collapses to a generic ``not_found`` code
and clients lose the ability to distinguish, e.g.,
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
"""
code = 404
def __init__(self, error: NodeOutputInspectorError) -> None:
self.error_code = error.code
super().__init__(description=str(error))
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
class WorkflowDraftRunNodeOutputsApi(Resource):
"""Whole-run snapshot organized by producer node."""
@console_ns.doc("get_workflow_draft_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowDraftRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status."""
@console_ns.doc("get_workflow_draft_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output (with signed URL for file refs)."""
@console_ns.doc("get_workflow_draft_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
# ──────────────────────────────────────────────────────────────────────────────
# SSE event stream — shared generator used by draft + published variants
# ──────────────────────────────────────────────────────────────────────────────
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
``data`` is JSON-serialized when given as a dict; raw strings are
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
"""
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
"""Yield SSE-framed strings for one workflow run.
The stream begins with a full ``snapshot`` event so the client has a
starting state without needing a separate REST GET. Then for every
``node_changed`` message from the pub/sub channel we re-read that node
from DB and push a fresh ``node_changed`` event. When the workflow run
reaches a terminal state we push one final ``workflow_run_completed``
event and close the stream.
Failures inside the loop are caught and surfaced as ``error`` events so
the frontend can show a banner rather than seeing the connection drop
silently. The Inspector never raises across the SSE boundary.
"""
service = _service()
run_id_str = str(run_id)
# Initial snapshot — also flushes a 404 back at the client right away
# if the run is gone (raised before yielding any bytes, so Flask turns it
# into the normal HTTP 404 path).
try:
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
except NodeOutputInspectorError as error:
raise _InspectorNotFound(error) from error
event_id = 0
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
# If the run already finished by the time the client connected, emit
# the terminal envelope synchronously and close — no point subscribing.
# The enum value for partial success is the hyphenated ``partial-succeeded``
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
event_id,
)
return
# Live subscription
ticks_since_heartbeat = 0
total_ticks = 0
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
total_ticks += 1
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
logger.warning(
"Inspector SSE: forcing close after %ds without terminal event for run %s",
_STREAM_HARD_TIMEOUT_TICKS,
run_id_str,
)
return
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
# ``node_changed`` message with both fields ``None`` on every redis
# timeout. Real ``workflow_completed`` messages keep their kind even
# when status couldn't be resolved (publisher race), so checking kind
# first makes the heartbeat branch safe.
if message.kind == "node_changed" and message.node_id is None and message.status is None:
ticks_since_heartbeat += 1
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
yield ":keepalive\n\n"
ticks_since_heartbeat = 0
continue
ticks_since_heartbeat = 0
if message.kind == "workflow_completed":
event_id += 1
yield _sse_envelope(
"workflow_run_completed",
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
event_id,
)
return
# node_changed: recompute the node slice from DB
if not message.node_id:
continue
try:
node_view = service.node_detail(
app_model=app_model,
workflow_run_id=run_id_str,
node_id=message.node_id,
)
except NodeOutputInspectorError:
# Node may not appear in the graph yet (race with persistence); skip.
continue
except Exception:
logger.warning(
"Inspector SSE: node_detail failed for run %s node %s",
run_id_str,
message.node_id,
exc_info=True,
)
event_id += 1
yield _sse_envelope(
"error",
{"node_id": message.node_id, "message": "failed to refresh node detail"},
event_id,
)
continue
event_id += 1
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
class WorkflowDraftRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a draft run."""
@console_ns.doc("stream_workflow_draft_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
# ──────────────────────────────────────────────────────────────────────────────
# Published-run endpoints — symmetric to the draft trio above
# ──────────────────────────────────────────────────────────────────────────────
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
class WorkflowPublishedRunNodeOutputsApi(Resource):
"""Whole-run snapshot for a *published* workflow run.
Same response shape as the ``/draft/`` variant — frontend can multiplex
based on which page (Composer test-run vs. Run History) is mounted.
"""
@console_ns.doc("get_workflow_published_run_node_outputs")
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return _serve_snapshot(app_model, run_id)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
"""One node's declared outputs + per-output status (published run)."""
@console_ns.doc("get_workflow_published_run_node_output_detail")
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
}
)
@console_ns.response(404, "Workflow run / node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str):
return _serve_node_detail(app_model, run_id, node_id)
@console_ns.route(
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
"/node-outputs/<string:node_id>/<string:output_name>/preview"
)
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
"""Full value for one declared output of a published run."""
@console_ns.doc("get_workflow_published_run_node_output_preview")
@console_ns.doc(description="Full value for one declared output of a published run.")
@console_ns.doc(
params={
"app_id": "Application ID",
"run_id": "Workflow run ID",
"node_id": "Node ID inside the workflow graph",
"output_name": "Declared output name as exposed by Composer",
}
)
@console_ns.response(404, "Workflow run / node / output not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
return _serve_output_preview(app_model, run_id, node_id, output_name)
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
"""SSE stream of inspector deltas for a published run."""
@console_ns.doc("stream_workflow_published_run_node_output_events")
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
return Response(
_stream_inspector_events(app_model, run_id),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: str):
def get(self, app_model: App, run_id: UUID):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)

View File

@ -1,16 +1,38 @@
"""Controller decorators for console app resources.
`with_session` opens one SQLAlchemy session for a request handler and injects it
as the first argument after `self`. Handlers use a transaction by default so
migrated write paths keep commit/rollback handling; pure read handlers may opt
out with `write=False`. App-loading decorators prefer that injected session when
present, while still supporting existing handlers that have not been migrated
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
"""
from collections.abc import Callable
from functools import wraps
from typing import overload
from typing import Concatenate, cast, overload
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console.app.error import AppNotFoundError
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
def _load_app_model(app_id: str) -> App | None:
def _load_app_model(session: Session, app_id: str) -> App | None:
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
_, current_tenant_id = current_account_with_tenant()
app_model = session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
return app_model
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
_, current_tenant_id = current_account_with_tenant()
app_model = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R],
*,
write: bool = True,
) -> Callable[Concatenate[T, P], R]: ...
@overload
def with_session[T, **P, R](
view: None = None,
*,
write: bool = True,
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
def with_session[T, **P, R](
view: Callable[Concatenate[T, Session, P], R] | None = None,
*,
write: bool = True,
) -> (
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
):
"""Inject a request-scoped session, using a transaction only for write handlers."""
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
if write:
with session_factory.get_session_maker().begin() as session:
return view(self, session, *args, **kwargs)
with session_factory.create_session() as session:
return view(self, session, *args, **kwargs)
return wrapper
if view is None:
return decorator
return decorator(view)
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
if len(args) < 2:
return None
candidate = args[1]
if isinstance(candidate, Session):
return candidate
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
return cast(Session, candidate)
return None
@overload
def get_app_model[**P, R](
view: Callable[P, R],
@ -44,6 +123,13 @@ def get_app_model[**P, R](
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Inject the App model for handlers that receive an `app_id` path parameter.
New handlers may compose `@with_session` above this decorator so the app row
is loaded through the same request-scoped session used by the controller.
Existing handlers continue to work through `db.session` until migrated.
"""
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@ -55,7 +141,11 @@ def get_app_model[**P, R](
del kwargs["app_id"]
app_model = _load_app_model(app_id)
session = _get_injected_session(args)
if session is None:
app_model = _load_app_model_from_scoped_session(app_id)
else:
app_model = _load_app_model(session, app_id)
if not app_model:
raise AppNotFoundError()

View File

@ -48,7 +48,6 @@ class NotionEstimatePayload(BaseModel):
class DataSourceNotionListQuery(BaseModel):
dataset_id: str | None = Field(default=None, description="Dataset ID")
credential_id: str = Field(..., description="Credential ID", min_length=1)
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
class DataSourceNotionPreviewQuery(BaseModel):
@ -205,9 +204,6 @@ class DataSourceNotionListApi(Resource):
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters = query.datasource_parameters or {}
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
@ -255,7 +251,7 @@ class DataSourceNotionListApi(Resource):
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters=datasource_parameters,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
)
)

View File

@ -979,7 +979,7 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@ -996,7 +996,7 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: str):
def post(self, dataset_id: UUID):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})

View File

@ -10,7 +10,12 @@ from controllers.common.fields import UsageCountResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
)
from fields.dataset_fields import (
dataset_detail_fields,
dataset_retrieval_model_fields,
@ -126,9 +131,9 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.response(200, "External API templates retrieved successfully")
@setup_required
@login_required
@with_current_tenant_id
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
def get(self, current_tenant_id: str):
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
@ -168,21 +169,22 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
def get(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
return variable
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
def patch(self, pipeline: Pipeline, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -210,11 +212,12 @@ class RagPipelineVariableApi(Resource):
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
@ -250,15 +253,16 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: str):
def delete(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -267,7 +271,7 @@ class RagPipelineVariableApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
def put(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -278,11 +282,12 @@ class RagPipelineVariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()

View File

@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: str):
def get(self, pipeline: Pipeline, run_id: UUID):
"""
Get workflow run node execution list
"""

View File

@ -21,13 +21,14 @@ from controllers.console.explore.error import (
NotCompletionAppError,
)
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models import Account
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@ -59,8 +60,8 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
)
class MessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
@ -96,8 +97,8 @@ class MessageListApi(InstalledAppResource):
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
def post(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, installed_app, message_id: UUID):
app_model = installed_app.app
message_id_str = str(message_id)
@ -124,8 +125,8 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app, message_id: UUID):
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -170,8 +171,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
def get(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app, message_id: UUID):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -9,9 +9,10 @@ from controllers.common.schema import register_response_schema_models, register_
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.login import current_account_with_tenant
from models import Account
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -22,8 +23,8 @@ register_response_schema_models(console_ns, ResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -46,8 +47,8 @@ class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -67,8 +68,8 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
@console_ns.response(204, "Saved message deleted successfully")
def delete(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, installed_app, message_id: UUID):
app_model = installed_app.app
message_id_str = str(message_id)

View File

@ -13,6 +13,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()

View File

@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
from . import console_ns
from .wraps import account_initialization_required, setup_required
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
class CodeBasedExtensionQuery(BaseModel):
@ -116,11 +116,11 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
]
@console_ns.doc("create_api_based_extension")
@ -130,9 +130,9 @@ class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension(
tenant_id=current_tenant_id,
@ -153,12 +153,12 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@ -169,9 +169,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, id: UUID):
@with_current_tenant_id
def post(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
@ -197,9 +197,9 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, id: UUID):
@with_current_tenant_id
def delete(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)

View File

@ -2,11 +2,11 @@ from flask_restx import Resource
from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_account_with_tenant, current_user, login_required
from libs.login import current_user, login_required
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
@ -24,10 +24,9 @@ class FeatureApi(Resource):
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
@ -49,10 +48,9 @@ class FeatureVectorSpaceApi(Resource):
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""Get vector-space usage and limit for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_vector_space(current_tenant_id).model_dump()

View File

@ -22,10 +22,13 @@ from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.file_service import FileService
from . import console_ns
@ -62,8 +65,8 @@ class FileApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("documents")
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
@ -107,10 +110,10 @@ class FilePreviewApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, file_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, file_id: UUID):
file_id_str = str(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
return {"content": text}

View File

@ -8,8 +8,14 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_user,
)
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
@ -70,11 +76,10 @@ class NotificationApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
def get(self, current_user: Account):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
@ -113,11 +118,11 @@ class NotificationDismissApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account):
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,

View File

@ -12,11 +12,13 @@ from controllers.common.errors import (
)
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import with_current_user
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models.account import Account
from services.file_service import FileService
@ -49,7 +51,8 @@ class RemoteFileUpload(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
@login_required
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
@ -74,12 +77,11 @@ class RemoteFileUpload(Resource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
user=current_user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:

View File

@ -4,6 +4,7 @@ from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
import services
from configs import dify_config
@ -22,15 +23,15 @@ from controllers.console.auth.error import (
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
from models.account import Account, TenantAccountJoin, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@ -79,6 +80,54 @@ def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
return list(dict.fromkeys(email.lower() for email in emails))
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
new_member_count = 0
for email in emails:
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
new_member_count += 1
continue
exists = db.session.scalar(
select(TenantAccountJoin.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not exists:
new_member_count += 1
return new_member_count
def _count_current_members(tenant_id: str) -> int:
return (
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
)
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
if new_member_count <= 0:
return
features = FeatureService.get_features(tenant_id=tenant_id)
if dify_config.ENTERPRISE_ENABLED:
workspace_members = features.workspace_members
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
raise WorkspaceMembersLimitExceeded()
return
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
members = features.members
current_member_count = _count_current_members(tenant_id)
if 0 < members.limit < current_member_count + new_member_count:
raise WorkspaceMembersLimitExceeded()
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@ -105,12 +154,11 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args.emails
invitee_emails = _normalize_invitee_emails(args.emails)
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
@ -130,37 +178,36 @@ class MemberInviteEmailApi(Resource):
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
tenant_id = inviter.current_tenant.id
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
_check_member_invite_limits(tenant_id, new_member_count)
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
return {
"result": "success",

View File

@ -4,6 +4,7 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from flask import abort, request
from sqlalchemy import select
@ -16,6 +17,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models import Account
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
@ -82,9 +84,7 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
if not dify_config.BILLING_ENABLED:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -198,15 +198,11 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
utm_info = request.cookies.get("utm_info")
if dify_config.BILLING_ENABLED and utm_info:
_, current_tenant_id = current_account_with_tenant()
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -309,7 +305,6 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
@ -327,7 +322,6 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
@ -495,3 +489,25 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return view(*args, **kwargs)
return decorated
def with_current_tenant_id[T, **P, R](
view: Callable[Concatenate[T, str, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_, current_tenant_id = current_account_with_tenant()
return view(self, current_tenant_id, *args, **kwargs)
return decorated
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated

View File

@ -0,0 +1,128 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
attach_anti_framing(bp)
api = ExternalApi(
bp,
version="1.0",
title="OpenAPI",
description="User-scoped programmatic API (bearer auth)",
)
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
# Register response/query models BEFORE importing controller modules so that
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.openapi._models import (
AccountPayload,
AccountResponse,
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppInfoResponse,
AppListQuery,
AppListResponse,
AppListRow,
AppRunRequest,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
MessageMetadata,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
RevokeResponse,
ServerVersionResponse,
SessionListResponse,
SessionRow,
TagItem,
UsageInfo,
WorkflowRunData,
WorkspaceDetailResponse,
WorkspaceListResponse,
WorkspacePayload,
WorkspaceSummaryResponse,
)
from fields.file_fields import FileResponse
register_schema_models(
openapi_ns,
AppDescribeQuery,
AppListQuery,
AppRunRequest,
DeviceCodeRequest,
DevicePollRequest,
DeviceLookupQuery,
DeviceMutateRequest,
PermittedExternalAppsListQuery,
)
register_response_schema_models(
openapi_ns,
TagItem,
UsageInfo,
MessageMetadata,
AppListRow,
AppListResponse,
AppInfoResponse,
AppDescribeInfo,
AppDescribeResponse,
WorkflowRunData,
AccountPayload,
WorkspacePayload,
AccountResponse,
SessionRow,
SessionListResponse,
PermittedExternalAppsListResponse,
RevokeResponse,
WorkspaceSummaryResponse,
WorkspaceListResponse,
WorkspaceDetailResponse,
DeviceCodeResponse,
DeviceLookupResponse,
DeviceMutateResponse,
FileResponse,
ServerVersionResponse,
)
from . import (
_meta,
account,
app_run,
apps,
apps_permitted_external,
files,
human_input_form,
index,
oauth_device,
oauth_device_sso,
workflow_events,
workspaces,
)
# Request models are imported from _models.py and registered above.
__all__ = [
"_meta",
"account",
"app_run",
"apps",
"apps_permitted_external",
"files",
"human_input_form",
"index",
"oauth_device",
"oauth_device_sso",
"workflow_events",
"workspaces",
]
api.add_namespace(openapi_ns)

View File

@ -0,0 +1,66 @@
"""Audit emission for openapi app-run endpoints.
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""
from __future__ import annotations
import logging
logger = logging.getLogger(__name__)
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
def emit_app_run(
*,
app_id: str,
tenant_id: str,
caller_kind: str,
mode: str,
surface: str,
) -> None:
logger.info(
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
EVENT_APP_RUN_OPENAPI,
app_id,
tenant_id,
caller_kind,
mode,
surface,
extra={
"audit": True,
"event": EVENT_APP_RUN_OPENAPI,
"app_id": app_id,
"tenant_id": tenant_id,
"caller_kind": caller_kind,
"mode": mode,
"surface": surface,
},
)
def emit_wrong_surface(
*,
subject_type: str | None,
attempted_path: str,
client_id: str | None,
token_id: str | None,
) -> None:
logger.warning(
"audit: %s subject_type=%s attempted_path=%s",
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
subject_type,
attempted_path,
extra={
"audit": True,
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
"subject_type": subject_type,
"attempted_path": attempted_path,
"client_id": client_id,
"token_id": token_id,
},
)

View File

@ -0,0 +1,143 @@
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
from __future__ import annotations
from typing import Any, cast
from controllers.service_api.app.error import AppUnavailableError
from models import App
from models.model import AppMode
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": {},
"required": [],
}
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
def _file_object_shape() -> dict[str, Any]:
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
return {
"type": "object",
"properties": {
"type": {"type": "string"},
"transfer_method": {"type": "string"},
"url": {"type": "string"},
"upload_file_id": {"type": "string"},
},
"additionalProperties": True,
}
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
label = row.get("label") or row.get("variable", "")
base: dict[str, Any] = {"title": label} if label else {}
if row_type in ("text-input", "paragraph"):
out: dict[str, Any] = {"type": "string"} | base
max_length = row.get("max_length")
if isinstance(max_length, int) and max_length > 0:
out["maxLength"] = max_length
return out
if row_type == "select":
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
if row_type == "number":
return {"type": "number"} | base
if row_type == "file":
return _file_object_shape() | base
if row_type == "file-list":
return {
"type": "array",
"items": _file_object_shape(),
} | base
return None
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
"""Translate a user_input_form row list into (properties, required-list).
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
Unknown variable types are skipped (forward-compat).
"""
properties: dict[str, Any] = {}
required: list[str] = []
for row in form:
if not isinstance(row, dict) or len(row) != 1:
continue
((row_type, row_body),) = row.items()
if not isinstance(row_body, dict):
continue
variable = row_body.get("variable")
if not variable:
continue
schema = _row_to_schema(row_type, row_body)
if schema is None:
continue
properties[variable] = schema
if row_body.get("required"):
required.append(variable)
return properties, required
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
Raises `AppUnavailableError` on misconfigured apps.
"""
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
return (
workflow.features_dict,
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
)
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = cast(dict[str, Any], app_model_config.to_dict())
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
def build_input_schema(app: App) -> dict[str, Any]:
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
completion / workflow: `inputs` object only.
Raises `AppUnavailableError` on misconfigured apps.
"""
_, user_input_form = resolve_app_config(app)
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
properties: dict[str, Any] = {}
required: list[str] = []
if app.mode in _CHAT_FAMILY:
properties["query"] = {"type": "string", "minLength": 1}
required.append("query")
properties["inputs"] = {
"type": "object",
"properties": inputs_props,
"required": inputs_required,
"additionalProperties": False,
}
required.append("inputs")
return {
"$schema": JSON_SCHEMA_DRAFT,
"type": "object",
"properties": properties,
"required": required,
}

View File

@ -0,0 +1,23 @@
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
Returns the server's project version and edition so the difyctl CLI can probe
compatibility without needing to be logged in. Mirrors the `_health` endpoint
in `index.py`.
"""
from flask_restx import Resource
from configs import dify_config
from controllers.openapi import openapi_ns
from controllers.openapi._models import ServerVersionResponse
@openapi_ns.route("/_version")
class VersionApi(Resource):
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
def get(self):
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
return ServerVersionResponse(
version=dify_config.project.version,
edition=edition,
).model_dump(mode="json")

View File

@ -0,0 +1,344 @@
"""Shared response substructures for openapi endpoints."""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from libs.helper import UUIDStrOrEmpty, uuid_value
from models.model import AppMode
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
MAX_PAGE_LIMIT = 200
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class MessageMetadata(BaseModel):
usage: UsageInfo | None = None
retriever_resources: list[dict[str, Any]] = []
class PaginationEnvelope[T](BaseModel):
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
page: int
limit: int
total: int
has_more: bool
data: list[T]
@classmethod
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
class TagItem(BaseModel):
name: str
class AppListRow(BaseModel):
id: str
name: str
description: str | None = None
mode: AppMode
tags: list[TagItem] = []
updated_at: str | None = None
created_by_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
class AppListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class PermittedExternalAppsListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[AppListRow]
class AppInfoResponse(BaseModel):
id: str
name: str
description: str | None = None
mode: str
author: str | None = None
tags: list[TagItem] = []
class AppDescribeInfo(AppInfoResponse):
updated_at: str | None = None
service_api_enabled: bool
is_agent: bool = False
class AppDescribeResponse(BaseModel):
info: AppDescribeInfo | None = None
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
class ChatMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
conversation_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class CompletionMessageResponse(BaseModel):
event: str
task_id: str
id: str
message_id: str
mode: str
answer: str
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
created_at: int
class WorkflowRunData(BaseModel):
id: str
workflow_id: str
status: str
outputs: dict[str, Any] = Field(default_factory=dict)
error: str | None = None
elapsed_time: float | None = None
total_tokens: int | None = None
total_steps: int | None = None
created_at: int | None = None
finished_at: int | None = None
class WorkflowRunResponse(BaseModel):
workflow_run_id: str
task_id: str
mode: Literal["workflow"] = "workflow"
data: WorkflowRunData
class AccountPayload(BaseModel):
id: str
email: str
name: str
class WorkspacePayload(BaseModel):
id: str
name: str
role: str
class AccountResponse(BaseModel):
subject_type: str
subject_email: str | None = None
subject_issuer: str | None = None
account: AccountPayload | None = None
workspaces: list[WorkspacePayload] = []
default_workspace_id: str | None = None
class SessionRow(BaseModel):
id: str
prefix: str
client_id: str
device_label: str
created_at: str | None = None
last_used_at: str | None = None
expires_at: str | None = None
class SessionListResponse(BaseModel):
page: int
limit: int
total: int
has_more: bool
data: list[SessionRow]
class RevokeResponse(BaseModel):
status: str
class WorkspaceSummaryResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
class WorkspaceListResponse(BaseModel):
workspaces: list[WorkspaceSummaryResponse]
class WorkspaceDetailResponse(BaseModel):
id: str
name: str
role: str
status: str
current: bool
created_at: str | None = None
class DeviceCodeResponse(BaseModel):
device_code: str
user_code: str
verification_uri: str
expires_in: int
interval: int
class DeviceLookupResponse(BaseModel):
valid: bool
expires_in_remaining: int = 0
client_id: str | None = None
class DeviceMutateResponse(BaseModel):
status: str
class ServerVersionResponse(BaseModel):
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
version: str
edition: Literal["SELF_HOSTED", "CLOUD"]
class AppDescribeQuery(BaseModel):
"""`?fields=` allow-list for GET /apps/<id>/describe.
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
"""
model_config = ConfigDict(extra="forbid")
fields: set[str] | None = None
workspace_id: str | None = None
@field_validator("workspace_id", mode="before")
@classmethod
def _validate_workspace_id(cls, v: object) -> str | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("workspace_id must be a string")
try:
import uuid as _uuid
_uuid.UUID(v)
except ValueError:
raise ValueError("workspace_id must be a valid UUID")
return v
@field_validator("fields", mode="before")
@classmethod
def _parse_fields(cls, v: object) -> set[str] | None:
if v is None or v == "":
return None
if not isinstance(v, str):
raise ValueError("fields must be a comma-separated string")
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
members = {m.strip() for m in v.split(",") if m.strip()}
unknown = members - _ALLOWED_DESCRIBE_FIELDS
if unknown:
raise ValueError(f"unknown field(s): {sorted(unknown)}")
return members
class AppListQuery(BaseModel):
"""mode is a closed enum."""
workspace_id: str
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
tag: str | None = Field(None, max_length=100)
class AppRunRequest(BaseModel):
inputs: dict[str, Any]
query: str | None = None
files: list[dict[str, Any]] | None = None
conversation_id: UUIDStrOrEmpty | None = None
auto_generate_name: bool = True
workflow_id: str | None = None
workspace_id: UUIDStrOrEmpty | None = None
@field_validator("conversation_id", mode="before")
@classmethod
def _normalize_conv(cls, value: str | None) -> str | None:
if isinstance(value, str):
value = value.strip()
if not value:
return None
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("conversation_id must be a valid UUID") from exc
class DeviceCodeRequest(BaseModel):
client_id: str
device_label: str
class DevicePollRequest(BaseModel):
device_code: str
client_id: str
class DeviceLookupQuery(BaseModel):
user_code: str
class DeviceMutateRequest(BaseModel):
user_code: str
class PermittedExternalAppsListQuery(BaseModel):
"""Strict (extra='forbid')."""
model_config = ConfigDict(extra="forbid")
page: int = Field(1, ge=1)
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
class ExtSubjectAssertionClaims(BaseModel):
email: str = _EMAIL_FIELD
issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
class ApprovalGrantClaimsPayload(BaseModel):
subject_email: str = _EMAIL_FIELD
subject_issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
csrf_token: str = Field(min_length=1, max_length=128)

View File

@ -0,0 +1,169 @@
from __future__ import annotations
from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
MAX_PAGE_LIMIT,
AccountPayload,
AccountResponse,
PaginationEnvelope,
RevokeResponse,
SessionListResponse,
SessionRow,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from services.account_service import AccountService, TenantService
from services.oauth_device_flow import (
list_active_sessions,
revoke_oauth_token,
token_belongs_to_subject,
)
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email,
subject_issuer=ctx.subject_issuer,
account=None,
workspaces=[],
default_workspace_id=None,
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email or (account.email if account else None),
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
).model_dump(mode="json")
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
all_rows = list_active_sessions(db.session, ctx, now)
total = len(all_rows)
sliced = all_rows[(page - 1) * limit : page * limit]
items = [
SessionRow(
id=str(r.id),
prefix=r.prefix,
client_id=r.client_id,
device_label=r.device_label,
created_at=_iso(r.created_at),
last_used_at=_iso(r.last_used_at),
expires_at=_iso(r.expires_at),
)
for r in sliced
]
return (
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
200,
)
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
# 404 (not 403) on cross-subject so the endpoint doesn't leak
# token IDs that belong to other subjects.
if not token_belongs_to_subject(db.session, session_id, ctx):
raise NotFound("session not found")
revoke_oauth_token(db.session, redis_client, session_id)
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt.isoformat().replace("+00:00", "Z")
def _pick_default_workspace(memberships) -> str | None:
if not memberships:
return None
for join, tenant in memberships:
if getattr(join, "current", False):
return str(tenant.id)
return str(memberships[0][1].id)
def _workspace_payload(row) -> WorkspacePayload:
join, tenant = row
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
def _account_payload(account) -> AccountPayload:
return AccountPayload(id=str(account.id), email=account.email, name=account.name)

View File

@ -0,0 +1,165 @@
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
from __future__ import annotations
import logging
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import AppRunRequest
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
ConversationCompletedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
IsDraftWorkflowError,
WorkflowIdFormatError,
WorkflowNotFoundError,
)
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
@contextmanager
def _translate_service_errors() -> Iterator[None]:
try:
yield
except WorkflowNotFoundError as ex:
raise NotFound(str(ex))
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
raise BadRequest(str(ex))
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except InvokeError as e:
raise CompletionRequestError(e.description)
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
return AppGenerateService.generate(
app_model=app,
user=caller,
args=args,
invoke_from=InvokeFrom.OPENAPI,
streaming=streaming,
)
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
if not payload.query or not payload.query.strip():
raise UnprocessableEntity("query_required_for_chat")
args = payload.model_dump(exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
args = payload.model_dump(exclude_none=True)
args["auto_generate_name"] = False
args.setdefault("query", "")
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
if payload.query is not None:
raise UnprocessableEntity("query_not_supported_for_workflow")
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
with _translate_service_errors():
return _generate(app, caller, args, streaming=True)
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
AppMode.CHAT: _run_chat,
AppMode.AGENT_CHAT: _run_chat,
AppMode.ADVANCED_CHAT: _run_chat,
AppMode.COMPLETION: _run_completion,
AppMode.WORKFLOW: _run_workflow,
}
@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
handler = _DISPATCH.get(app_model.mode)
if handler is None:
raise UnprocessableEntity("mode_not_runnable")
try:
stream_obj = handler(app_model, caller, payload)
except HTTPException:
raise
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
emit_app_run(
app_id=app_model.id,
tenant_id=app_model.tenant_id,
caller_kind=caller_kind,
mode=str(app_model.mode),
surface="apps",
)
return helper.compact_generate_response(stream_obj)
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -0,0 +1,270 @@
"""GET /openapi/v1/apps and per-app reads.
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is last → outermost → publishes the auth ContextVar before `require_scope`
reads it.
"""
from __future__ import annotations
import uuid as _uuid
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
from controllers.common.fields import Parameters
from controllers.common.schema import query_params_from_model
from controllers.openapi import openapi_ns
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
from controllers.openapi._models import (
AppDescribeInfo,
AppDescribeQuery,
AppDescribeResponse,
AppListQuery,
AppListResponse,
AppListRow,
TagItem,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
get_auth_ctx,
require_scope,
require_workspace_member,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
require_scope(Scope.APPS_READ),
accept_subjects(SubjectType.ACCOUNT),
validate_bearer(accept=ACCEPT_USER_ANY),
]
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
_EMPTY_PARAMETERS: dict[str, Any] = {
"opening_statement": None,
"suggested_questions": [],
"user_input_form": [],
"file_upload": None,
"system_parameters": {},
}
class AppReadResource(Resource):
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
method_decorators = _APPS_READ_DECORATORS
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
ctx: AuthContext = get_auth_ctx()
try:
parsed_uuid = _uuid.UUID(app_id)
is_uuid = True
except ValueError:
parsed_uuid = None
is_uuid = False
if is_uuid:
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None:
raise NotFound("app not found")
else:
if not workspace_id:
raise UnprocessableEntity("workspace_id is required for name-based lookup")
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
if len(matches) == 0:
raise NotFound("app not found")
if len(matches) > 1:
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
for m in matches:
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
raise Conflict("".join(lines))
app = matches[0]
require_workspace_member(ctx, str(app.tenant_id))
return app, ctx
def parameters_payload(app: App) -> dict:
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
features_dict, user_input_form = resolve_app_config(app)
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
def get(self, app_id: str):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
app, _ = self._load(app_id, workspace_id=query.workspace_id)
requested = query.fields
want_info = requested is None or "info" in requested
want_params = requested is None or "parameters" in requested
want_schema = requested is None or "input_schema" in requested
info = (
AppDescribeInfo(
id=str(app.id),
name=app.name,
mode=app.mode,
description=app.description,
tags=[TagItem(name=t.name) for t in app.tags],
author=app.author_name,
updated_at=app.updated_at.isoformat() if app.updated_at else None,
service_api_enabled=bool(app.enable_api),
is_agent=app.mode in ("agent-chat", "advanced-chat"),
)
if want_info
else None
)
parameters: dict[str, Any] | None = None
input_schema: dict[str, Any] | None = None
if want_params:
try:
parameters = parameters_payload(app)
except AppUnavailableError:
parameters = dict(_EMPTY_PARAMETERS)
if want_schema:
try:
input_schema = build_input_schema(app)
except AppUnavailableError:
input_schema = dict(EMPTY_INPUT_SCHEMA)
return (
AppDescribeResponse(
info=info,
parameters=parameters,
input_schema=input_schema,
).model_dump(mode="json", exclude_none=False),
200,
)
@openapi_ns.route("/apps")
class AppListApi(Resource):
method_decorators = _APPS_READ_DECORATORS
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
def get(self):
ctx: AuthContext = get_auth_ctx()
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
workspace_id = query.workspace_id
require_workspace_member(ctx, workspace_id)
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
mode="json"
),
200,
)
if query.name:
try:
parsed_uuid = _uuid.UUID(query.name)
except ValueError:
parsed_uuid = None
else:
parsed_uuid = None
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
if app is None or str(app.tenant_id) != workspace_id:
return empty
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
item = AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[TagItem(name=t.name) for t in app.tags],
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=getattr(app, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
return env.model_dump(mode="json"), 200
tag_ids: list[str] | None = None
if query.tag:
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
if not tags:
return empty
tag_ids = [tag.id for tag in tags]
params = AppListParams(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else "all", # type:ignore
name=query.name,
tag_ids=tag_ids,
status="normal",
# Visibility gate pushed into the query — pagination.total stays
# consistent across pages because invisible rows never count.
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
if pagination is None:
return empty
tenant_name = None
if pagination.items:
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
items = [
AppListRow(
id=str(r.id),
name=r.name,
description=r.description,
mode=r.mode,
tags=[TagItem(name=t.name) for t in r.tags],
updated_at=r.updated_at.isoformat() if r.updated_at else None,
created_by_name=getattr(r, "author_name", None),
workspace_id=str(workspace_id),
workspace_name=tenant_name,
)
for r in pagination.items
]
env = AppListResponse(
page=query.page,
limit=query.limit,
total=cast(int, pagination.total),
has_more=query.page * query.limit < cast(int, pagination.total),
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,102 @@
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
(public / sso_verified). License-gated: CE deploys never enable the
EE blueprint chain so this module is unreachable there.
"""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from pydantic import ValidationError
from werkzeug.exceptions import UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AppListRow,
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
method_decorators = [
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
license_required,
accept_subjects(SubjectType.EXTERNAL_SSO),
validate_bearer(accept=ACCEPT_USER_ANY),
enterprise_only,
]
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
def get(self):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
page_result = list_permitted_apps(
page=query.page,
limit=query.limit,
mode=query.mode.value if query.mode else None,
name=query.name,
)
if not page_result.app_ids:
env = PermittedExternalAppsListResponse(
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
)
return env.model_dump(mode="json"), 200
apps_by_id: dict[str, App] = {
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
}
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
items: list[AppListRow] = []
for app_id in page_result.app_ids:
app = apps_by_id.get(app_id)
if not app or app.status != "normal":
continue
tenant = tenants_by_id.get(str(app.tenant_id))
items.append(
AppListRow(
id=str(app.id),
name=app.name,
description=app.description,
mode=app.mode,
tags=[], # tenant-scoped; not surfaced cross-tenant
updated_at=app.updated_at.isoformat() if app.updated_at else None,
created_by_name=None, # cross-tenant author leak prevention
workspace_id=str(app.tenant_id),
workspace_name=tenant.name if tenant else None,
)
)
env = PermittedExternalAppsListResponse(
page=query.page,
limit=query.limit,
total=page_result.total,
has_more=query.page * query.limit < page_result.total,
data=items,
)
return env.model_dump(mode="json"), 200

View File

@ -0,0 +1,3 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
__all__ = ["OAUTH_BEARER_PIPELINE"]

View File

@ -0,0 +1,46 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
if FeatureService.get_system_features().webapp_auth.enabled:
return AclStrategy()
return MembershipStrategy()
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
)

View File

@ -0,0 +1,68 @@
"""Mutable per-request context for the openapi auth pipeline.
Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context — handlers
read populated values via the decorator's kwargs unpacking.
Context is intentionally decoupled from Flask's ``Request``: the pipeline
guard extracts whatever transport-level inputs the steps need (bearer
token, path params) at the boundary and writes them into Context fields,
so steps stay testable without a request object and won't leak coupling
to a specific framework.
"""
from __future__ import annotations
import uuid
from collections.abc import Mapping
from contextvars import Token
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from werkzeug.exceptions import Unauthorized
from libs.oauth_bearer import AuthContext, Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
required_scope: Scope
bearer_token: str | None = None
path_params: Mapping[str, str] = field(default_factory=dict)
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None
auth_ctx_reset_token: Token[AuthContext] | None = None
@property
def must_tenant(self) -> Tenant:
if not self.tenant:
raise Unauthorized("tenant is not associated")
return self.tenant
@property
def must_subject_type(self) -> SubjectType:
if not self.subject_type:
raise Unauthorized("subject_type unset — BearerCheck did not run")
return self.subject_type
class Step(Protocol):
"""One responsibility. Mutate ctx or raise to short-circuit."""
def __call__(self, ctx: Context) -> None: ...

View File

@ -0,0 +1,51 @@
"""Pipeline IS the auth scheme.
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
"""
from __future__ import annotations
from functools import wraps
from flask import request
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
class Pipeline:
def __init__(self, *steps: Step) -> None:
self._steps = steps
def run(self, ctx: Context) -> None:
for step in self._steps:
step(ctx)
def guard(self, *, scope: Scope):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
# Extract transport-level inputs at the boundary so steps
# stay decoupled from Flask's request object.
ctx = Context(
required_scope=scope,
bearer_token=extract_bearer(request),
path_params=dict(request.view_args or {}),
)
try:
self.run(ctx)
kwargs.update(
app_model=ctx.app,
caller=ctx.caller,
caller_kind=ctx.caller_kind,
)
return view(*args, **kwargs)
finally:
if ctx.auth_ctx_reset_token is not None:
reset_auth_ctx(ctx.auth_ctx_reset_token)
return decorated
return decorator

View File

@ -0,0 +1,170 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also publishes the
resolved identity to the openapi auth ``ContextVar`` (the same one the
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
surface gate and any handler reading the request-scoped context has a single
source of truth across both auth-attach paths. The reset token is stashed
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
its `finally` so worker-thread reuse can't leak identity across requests.
"""
from __future__ import annotations
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
check_workspace_membership,
get_authenticator,
set_auth_ctx,
)
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here.
Also publishes the resolved `AuthContext` via
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
``validate_bearer`` writes — so the surface gate + downstream readers
don't see two different identity sources. The reset token is parked on
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
def __call__(self, ctx: Context) -> None:
if not ctx.bearer_token:
raise Unauthorized("bearer required")
try:
authn = get_authenticator().authenticate(ctx.bearer_token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
class SurfaceCheck:
"""Reject the request if the resolved subject is not in `accepted`."""
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
self._accepted = accepted
def __call__(self, ctx: Context) -> None:
check_surface(self._accepted)
class AppResolver:
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route — that is the design lock-in (no body /
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
``ctx.path_params`` at the boundary so this step doesn't need to know
about the request object.
"""
def __call__(self, ctx: Context) -> None:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only — SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.must_tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy
def __call__(self, ctx: Context) -> None:
if not self._resolve().authorize(ctx):
raise Forbidden("subject_no_app_access")
class CallerMount:
def __init__(self, *mounters: CallerMounter) -> None:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.must_subject_type):
m.mount(ctx)
return
raise Unauthorized("no caller mounter for subject type")
__all__ = [
"AppAuthzCheck",
"AppResolver",
"AuthContext",
"BearerCheck",
"CallerMount",
"ScopeCheck",
"SurfaceCheck",
"WorkspaceMembershipCheck",
]

View File

@ -0,0 +1,168 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
"""Per-app ACL, evaluated in two stages.
The EE gateway has already enforced tenancy and workspace membership
by the time this strategy runs, so AclStrategy only owns per-app ACL:
1. Subject vs access-mode compatibility (pure rule table). External-SSO
bearers belong to public-facing apps only; account bearers cover the
full set. A mismatch is an immediate deny — no IO.
2. For modes that pair with the subject, decide whether the inner
permission API must run. Only `PRIVATE` (per-app selected-user list)
requires it; the remaining modes are pass-through.
"""
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
SubjectType.ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
SubjectType.EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
def authorize(self, ctx: Context) -> bool:
if ctx.app is None:
return False
access_mode = self._fetch_access_mode(ctx.app.id)
if access_mode is None:
return False
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
return False
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
return True
return self._inner_permission_check(ctx)
@staticmethod
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if settings is None:
return None
try:
return WebAppAccessMode(settings.access_mode)
except ValueError:
return None
@classmethod
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
def _inner_permission_check(self, ctx: Context) -> bool:
if ctx.app is None:
return False
user_id = self._resolve_user_id(ctx)
if user_id is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_id=ctx.app.id,
)
@staticmethod
def _resolve_user_id(ctx: Context) -> str | None:
if ctx.subject_type == SubjectType.ACCOUNT:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
class MembershipStrategy:
"""Tenant-membership fallback.
Used when webapp-auth is disabled (CE deployment). Account-bearing
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
denied (it requires the webapp-auth surface).
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
if ctx.tenant is None:
return False
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # type:ignore
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
class CallerMounter(Protocol):
def applies_to(self, subject_type: SubjectType) -> bool: ...
def mount(self, ctx: Context) -> None: ...
class AccountMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)
ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -0,0 +1,89 @@
"""Surface gate.
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
step) is the pipeline-level form. Both delegate to `check_surface` so the
audit emit + canonical-path message are single-sourced.
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
``openapi.wrong_surface_denied``.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from flask import request
from werkzeug.exceptions import Forbidden
from controllers.openapi._audit import emit_wrong_surface
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
_CANONICAL_PATH: dict[SubjectType, str] = {
SubjectType.ACCOUNT: "/openapi/v1/apps",
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
}
F = TypeVar("F", bound=Callable[..., object])
def check_surface(accepted: frozenset[SubjectType]) -> None:
"""Enforce that the resolved subject is in ``accepted``.
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
set the bearer layer didn't run — that's a wiring bug, not a
user-driven failure, so surface it as a ``RuntimeError`` instead of
a silent 403.
"""
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
)
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
if subject in accepted:
return
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
emit_wrong_surface(
subject_type=subject.value if subject else None,
attempted_path=request.path,
client_id=getattr(ctx, "client_id", None),
token_id=_stringify(getattr(ctx, "token_id", None)),
)
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
accepted_set: frozenset[SubjectType] = frozenset(accepted)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
check_surface(accepted_set)
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco
def _coerce_subject_type(raw: object) -> SubjectType | None:
if raw is None:
return None
if isinstance(raw, SubjectType):
return raw
if isinstance(raw, str):
return SubjectType(raw)
return None
def _stringify(value: object) -> str | None:
if value is None:
return None
return str(value)

View File

@ -0,0 +1,72 @@
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
from __future__ import annotations
from flask import request
from flask_restx import Resource
from flask_restx.api import HTTPStatus
from werkzeug.exceptions import BadRequest
import services
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from extensions.ext_database import db
from fields.file_fields import FileResponse
from libs.oauth_bearer import Scope
from models import Account, App
from services.file_service import FileService
@openapi_ns.route("/apps/<string:app_id>/files/upload")
class AppFileUploadApi(Resource):
@openapi_ns.doc("upload_file_for_app_input")
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
@openapi_ns.doc(
responses={
201: "File uploaded successfully",
400: "Bad request — no file or filename missing",
401: "Unauthorized — invalid or expired bearer token",
413: "File too large",
415: "Unsupported file type or blocked extension",
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
file = request.files["file"]
if not file.mimetype:
raise UnsupportedFileTypeError()
if not file.filename:
raise FilenameNotExistsError()
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.stream.read(),
mimetype=file.mimetype,
user=caller,
)
except ValueError as exc:
raise BadRequest(str(exc))
except services.errors.file.FileTooLargeError as exc:
raise FileTooLargeError(exc.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
except services.errors.file.BlockedFileExtensionError as exc:
raise BlockedFileExtensionError(exc.description)
response = FileResponse.model_validate(upload_file, from_attributes=True)
return response.model_dump(mode="json"), 201

View File

@ -0,0 +1,107 @@
"""
OpenAPI bearer-authed human input form endpoints.
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
"""
from __future__ import annotations
import json
import logging
from flask import Response, request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
from libs.oauth_bearer import Scope
from models.model import App
from services.human_input_service import FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
def _jsonify_form_definition(form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_openapi(form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
raise NotFound("Form not found")
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_openapi(form)
submission_user_id: str | None = None
submission_end_user_id: str | None = None
if caller_kind == "account":
submission_user_id = caller.id
else:
submission_end_user_id = caller.id
if form.recipient_type is None:
logger.warning("Recipient type is None for form, form_token=%s", form_token)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=form.recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=submission_user_id,
submission_end_user_id=submission_end_user_id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -0,0 +1,9 @@
from flask_restx import Resource
from controllers.openapi import openapi_ns
@openapi_ns.route("/_health")
class HealthApi(Resource):
def get(self):
return {"ok": True}

View File

@ -0,0 +1,398 @@
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
sub-groups in one module:
Protocol (RFC 8628, public + rate-limited):
POST /oauth/device/code
POST /oauth/device/token
GET /oauth/device/lookup
Approval (account branch, console-cookie authed):
POST /oauth/device/approve
POST /oauth/device/deny
SSO branch lives in oauth_device_sso.py.
"""
from __future__ import annotations
import logging
from typing import Any
from flask import request
from flask_login import login_required
from flask_restx import Resource
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from configs import dify_config
from controllers.common.schema import query_params_from_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
AccountPayload,
DeviceCodeRequest,
DeviceCodeResponse,
DeviceLookupQuery,
DeviceLookupResponse,
DeviceMutateRequest,
DeviceMutateResponse,
DevicePollRequest,
WorkspacePayload,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
from libs.rate_limit import (
LIMIT_APPROVE_CONSOLE,
LIMIT_DEVICE_CODE_PER_IP,
LIMIT_LOOKUP_PUBLIC,
rate_limit,
)
from services.account_service import TenantService
from services.oauth_device_flow import (
ACCOUNT_ISSUER_SENTINEL,
DEFAULT_POLL_INTERVAL_SECONDS,
DEVICE_FLOW_TTL_SECONDS,
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
SlowDownDecision,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# =========================================================================
# Validation helpers
# =========================================================================
def _validate_json[M: BaseModel](model: type[M]) -> M:
body = request.get_json(silent=True) or {}
try:
return model.model_validate(body)
except ValidationError as exc:
raise BadRequest(str(exc))
def _validate_query[M: BaseModel](model: type[M]) -> M:
try:
return model.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise BadRequest(str(exc))
# =========================================================================
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
# =========================================================================
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
def post(self):
payload = _validate_json(DeviceCodeRequest)
client_id = payload.client_id
device_label = payload.device_label
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
return {"error": "unsupported_client"}, 400
store = DeviceFlowRedis(redis_client)
ip = extract_remote_ip(request)
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
return {
"device_code": device_code,
"user_code": user_code,
"verification_uri": _verification_uri(),
"expires_in": expires_in,
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
}, 200
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
"""RFC 8628 poll."""
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
def post(self):
payload = _validate_json(DevicePollRequest)
device_code = payload.device_code
store = DeviceFlowRedis(redis_client)
# slow_down beats every other branch — polling-too-fast clients
# see only that response regardless of underlying state.
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
return {"error": "slow_down"}, 400
state = store.load_by_device_code(device_code)
if state is None:
return {"error": "expired_token"}, 400
if state.status is DeviceFlowStatus.PENDING:
return {"error": "authorization_pending"}, 400
terminal = store.consume_on_poll(device_code)
if terminal is None:
return {"error": "expired_token"}, 400
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
_audit_cross_ip_if_needed(state)
return poll_payload, 200
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
"""Read-only — public for pre-validate before login. user_code is
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
"""
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
@rate_limit(LIMIT_LOOKUP_PUBLIC)
def get(self):
payload = _validate_query(DeviceLookupQuery)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
_device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
return {
"valid": True,
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
"client_id": state.client_id,
}, 200
# =========================================================================
# Approval endpoints — account branch (cookie-authed)
# =========================================================================
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
account, tenant = current_account_with_tenant()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
# SET NX guard — without it, two in-flight approves both pass
# PENDING, both mint, and the second upsert silently rotates the
# first caller into an already-revoked token.
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
return {"error": "approve_in_progress"}, 409
try:
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=tenant)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=account.email,
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
account_id=str(account.id),
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
poll_payload = _build_account_poll_payload(account, tenant, mint)
try:
store.approve(
device_code,
subject_email=account.email,
account_id=str(account.id),
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError):
# Row minted but state vanished — roll forward; the orphan
# token is revocable via auth devices list / Authorized Apps.
logger.exception("device_flow: approve raced on %s", device_code)
return {"error": "state_lost"}, 409
finally:
redis_client.delete(guard_key)
_emit_approve_audit(state, account, tenant, mint)
return {"status": "approved"}, 200
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@bearer_feature_required
@rate_limit(LIMIT_APPROVE_CONSOLE)
def post(self):
payload = _validate_json(DeviceMutateRequest)
user_code = payload.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
return {"error": "expired_or_unknown"}, 404
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
return {"error": "already_resolved"}, 409
try:
store.deny(device_code)
except (StateNotFoundError, InvalidTransitionError):
logger.exception("device_flow: deny raced on %s", device_code)
return {"error": "state_lost"}, 409
_emit_deny_audit(state)
return {"status": "denied"}, 200
# =========================================================================
# Helpers
# =========================================================================
def _verification_uri() -> str:
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
if base:
return f"{base.rstrip('/')}/device"
return f"{request.host_url.rstrip('/')}/device"
def _audit_cross_ip_if_needed(state) -> None:
poll_ip = extract_remote_ip(request)
if state.created_ip and poll_ip and poll_ip != state.created_ip:
logger.warning(
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
state.token_id,
state.created_ip,
poll_ip,
extra={
"audit": True,
"token_id": state.token_id,
"creation_ip": state.created_ip,
"poll_ip": poll_ip,
},
)
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
# Prefer active session tenant → DB-flagged current join → first membership.
default_ws_id = None
if tenant and any(w.id == str(tenant) for w in workspaces):
default_ws_id = str(tenant)
if default_ws_id is None:
for _t, m in rows:
if getattr(m, "current", False):
default_ws_id = str(m.tenant_id)
break
if default_ws_id is None and workspaces:
default_ws_id = workspaces[0].id
payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.ACCOUNT,
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
"workspaces": [w.model_dump(mode="json") for w in workspaces],
"default_workspace_id": default_ws_id,
"token_id": str(mint.token_id),
}
return payload
def _emit_approve_audit(state, account, tenant, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
mint.token_id,
account.email,
state.client_id,
state.device_label,
mint.expires_at,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"token_id": str(mint.token_id),
"subject_type": SubjectType.ACCOUNT,
"subject_email": account.email,
"account_id": str(account.id),
"tenant_id": tenant,
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["full"],
"expires_at": mint.expires_at.isoformat(),
},
)
def _emit_deny_audit(state) -> None:
logger.warning(
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
state.client_id,
state.device_label,
extra={
"audit": True,
"event": "oauth.device_flow_denied",
"client_id": state.client_id,
"device_label": state.device_label,
},
)

View File

@ -0,0 +1,365 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
EE-only. Browser flow:
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
Function-based (raw @bp.route) rather than Resource classes because the
handlers do redirects + cookie kwargs that don't fit the Resource shape.
"""
from __future__ import annotations
import logging
import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from pydantic import ValidationError
from werkzeug.exceptions import (
BadGateway,
BadRequest,
Conflict,
Forbidden,
NotFound,
Unauthorized,
)
from configs import dify_config
from controllers.openapi import bp
from controllers.openapi._models import ExtSubjectAssertionClaims
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import jws
from libs.device_flow_security import (
APPROVAL_GRANT_COOKIE_NAME,
ApprovalGrantClaims,
approval_grant_cleared_cookie_kwargs,
approval_grant_cookie_kwargs,
consume_approval_grant_nonce,
consume_sso_assertion_nonce,
enterprise_only,
mint_approval_grant,
verify_approval_grant,
)
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
from libs.rate_limit import (
LIMIT_APPROVE_EXT_PER_EMAIL,
LIMIT_SSO_INITIATE_PER_IP,
enforce,
rate_limit,
)
from services.account_service import AccountService
from services.enterprise.enterprise_service import EnterpriseService
from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
)
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
logger = logging.getLogger(__name__)
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
def _trusted_origin() -> str:
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
if not base:
raise BadGateway("console_api_url_unset")
return base
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
user_code = (request.args.get("user_code") or "").strip().upper()
if not user_code:
raise BadRequest("user_code required")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise BadRequest("invalid_user_code")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise BadRequest("invalid_user_code")
origin = _trusted_origin()
keyset = jws.KeySet.from_shared_secret()
signed_state = jws.sign(
keyset,
payload={
"redirect_url": "",
"app_code": "",
"intent": "device_flow",
"user_code": user_code,
"nonce": secrets.token_urlsafe(16),
"return_to": "",
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
},
aud=jws.AUD_STATE_ENVELOPE,
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
)
try:
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
except Exception as e:
logger.warning("sso-initiate: enterprise call failed: %s", e)
raise BadGateway("sso_initiate_failed") from e
url = (reply or {}).get("url")
if not url:
raise BadGateway("sso_initiate_missing_url")
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
resp = redirect(url, code=302)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
blob = request.args.get("sso_assertion")
if not blob:
raise BadRequest("sso_assertion required")
keyset = jws.KeySet.from_shared_secret()
try:
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
try:
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
except ValidationError as e:
logger.warning("sso-complete: claim shape invalid: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
raise BadRequest("invalid_sso_assertion")
user_code = claims.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
raise Conflict("user_code_not_pending")
_, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims.email):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
reason="email_belongs_to_dify_account",
)
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
iss = _trusted_origin()
cookie_value, _ = mint_approval_grant(
keyset=keyset,
iss=iss,
subject_email=claims.email,
subject_issuer=claims.issuer,
user_code=user_code,
)
resp = redirect("/device?sso_verified=1", code=302)
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
return resp
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("no_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approval-context: bad cookie: %s", e)
raise Unauthorized("no_session") from e
return jsonify(
{
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"user_code": claims.user_code,
"csrf_token": claims.csrf_token,
"expires_at": claims.expires_at.isoformat(),
}
), 200
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
if not token:
raise Unauthorized("invalid_session")
keyset = jws.KeySet.from_shared_secret()
try:
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
except jws.VerifyError as e:
logger.warning("approve-external: bad cookie: %s", e)
raise Unauthorized("invalid_session") from e
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
csrf_header = request.headers.get("X-CSRF-Token", "")
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
raise Forbidden("csrf_mismatch")
data = request.get_json(silent=True) or {}
body_user_code = (data.get("user_code") or "").strip().upper()
if body_user_code != claims.user_code:
raise BadRequest("user_code_mismatch")
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(claims.user_code)
if found is None:
raise NotFound("user_code_not_pending")
device_code, state = found
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
raise Forbidden("email_belongs_to_dify_account")
if not consume_approval_grant_nonce(redis_client, claims.nonce):
raise Unauthorized("session_already_consumed")
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
try:
validate_mint_policy(
subject_type=profile.subject_type,
prefix=profile.prefix,
scopes=profile.scopes,
)
except MintPolicyViolation as e:
raise BadRequest(description=str(e)) from None
ttl_days = oauth_ttl_days(tenant_id=None)
mint = mint_oauth_token(
db.session,
redis_client,
subject_email=claims.subject_email,
subject_issuer=claims.subject_issuer,
account_id=None,
client_id=state.client_id,
device_label=state.device_label,
prefix=profile.prefix,
ttl_days=ttl_days,
)
# SSO branch of the shared PollPayload contract: account/workspace
# fields are zero-filled (`None` / `[]`) for parity with the account
# branch in `oauth_device._build_account_poll_payload`.
poll_payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"account": None,
"workspaces": [],
"default_workspace_id": None,
"token_id": str(mint.token_id),
}
try:
store.approve(
device_code,
subject_email=claims.subject_email,
account_id=None,
subject_issuer=claims.subject_issuer,
minted_token=mint.token,
token_id=str(mint.token_id),
poll_payload=poll_payload,
)
except (StateNotFoundError, InvalidTransitionError) as e:
logger.exception("approve-external: state transition raced")
raise Conflict("state_lost") from e
_emit_approve_external_audit(state, claims, mint)
resp = make_response(jsonify({"status": "approved"}), 200)
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
return resp
@dataclass(frozen=True)
class _RejectedClaims:
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
event without reaching for the full dataclass.
"""
subject_email: str
subject_issuer: str
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
logger.warning(
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
reason,
extra={
"audit": True,
"event": "oauth.device_flow_rejected",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"reason": reason,
"client_id": state.client_id,
"device_label": state.device_label,
},
)
def _emit_approve_external_audit(state, claims, mint) -> None:
logger.warning(
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
SubjectType.EXTERNAL_SSO,
claims.subject_email,
claims.subject_issuer,
mint.token_id,
extra={
"audit": True,
"event": "oauth.device_flow_approved",
"subject_type": SubjectType.EXTERNAL_SSO,
"subject_email": claims.subject_email,
"subject_issuer": claims.subject_issuer,
"token_id": str(mint.token_id),
"client_id": state.client_id,
"device_label": state.device_label,
"scopes": ["apps:run"],
"expires_at": mint.expires_at.isoformat(),
},
)

View File

@ -0,0 +1,119 @@
"""
OpenAPI bearer-authed workflow reconnect event stream endpoint.
GET /apps/<app_id>/tasks/<task_id>/events
— reconnect to the SSE stream for a paused/running workflow run.
`task_id` is treated as `workflow_run_id`.
"""
from __future__ import annotations
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from libs.oauth_bearer import Scope
from models.enums import CreatorUserRole
from models.model import App, AppMode
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.response(200, "SSE event stream")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if caller_kind == "account":
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
else:
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=caller,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
else:
generator = WorkflowAppGenerator()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.OPENAPI,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)

View File

@ -0,0 +1,78 @@
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
counterparts to the cookie-authed /console/api/workspaces endpoints.
Account bearers (dfoa_) see every tenant they're a member of. External
SSO bearers (dfoe_) have no account_id and so see an empty list — that
matches /openapi/v1/account.
"""
from __future__ import annotations
from itertools import starmap
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
from controllers.openapi.auth.surface_gate import accept_subjects
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from models import Tenant, TenantAccountJoin
from services.account_service import TenantService
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self):
ctx = get_auth_ctx()
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self, workspace_id: str):
ctx = get_auth_ctx()
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
if row is None:
raise NotFound("workspace not found")
tenant, membership = row
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
return WorkspaceSummaryResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
)
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
return WorkspaceDetailResponse(
id=str(tenant.id),
name=tenant.name,
role=getattr(membership, "role", ""),
status=tenant.status,
current=getattr(membership, "current", False),
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
)

View File

@ -174,11 +174,11 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
def put(self, app_model: App, annotation_id: UUID):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -195,7 +195,7 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def delete(self, app_model: App, annotation_id: str):
def delete(self, app_model: App, annotation_id: UUID):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
return "", 204

View File

@ -1,5 +1,6 @@
import logging
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -50,20 +51,20 @@ class FilePreviewApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: str):
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
"""
Preview/Download a file that was uploaded via Service API.
Provides secure file preview/download functionality.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
file_id = str(file_id)
file_id_str = str(file_id)
# Parse query parameters
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
# Get file content generator
try:

View File

@ -1,5 +1,6 @@
from collections.abc import Generator
from typing import Any
from uuid import UUID
from flask import request
from pydantic import BaseModel
@ -64,10 +65,11 @@ class DatasourcePluginsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id: str, dataset_id: str):
def get(self, tenant_id: str, dataset_id: UUID):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -77,7 +79,7 @@ class DatasourcePluginsApi(DatasetApiResource):
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
)
return datasource_plugins, 200
@ -109,10 +111,11 @@ class DatasourceNodeRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
def post(self, tenant_id: str, dataset_id: UUID, node_id: str):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -120,7 +123,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
@ -172,10 +175,11 @@ class PipelineRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
def post(self, tenant_id: str, dataset_id: UUID):
"""Resource for running a rag pipeline."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -186,7 +190,7 @@ class PipelineRunApi(DatasetApiResource):
raise Forbidden()
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
try:
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,

View File

@ -1,4 +1,5 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal
@ -107,17 +108,19 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
@ -150,7 +153,10 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
return {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
}, 200
else:
return {"error": "Segments is required"}, 400
@ -165,19 +171,21 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset_id_str = str(dataset_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
# check embedding model setting
@ -205,7 +213,7 @@ class SegmentApi(DatasetApiResource):
)
segments, total = SegmentService.get_segments(
document_id=document_id,
document_id=document_id_str,
tenant_id=current_tenant_id,
status_list=args.status,
keyword=args.keyword,
@ -214,7 +222,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": _marshal_segments_with_summary(segments, dataset_id),
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -240,22 +248,25 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@ -276,18 +287,20 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -306,15 +319,19 @@ class DatasetSegmentApi(DatasetApiResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
return {
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
"doc_form": document.doc_form,
}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@ -325,26 +342,29 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
@service_api_ns.route(
@ -369,23 +389,26 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -429,23 +452,26 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -461,7 +487,9 @@ class ChildChunkApi(DatasetApiResource):
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
@ -497,32 +525,38 @@ class DatasetChildChunkApi(DatasetApiResource):
)
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id):
if str(segment.document_id) != str(document_id_str):
raise NotFound("Document not found.")
child_chunk_id_str = str(child_chunk_id)
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -558,32 +592,38 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# get document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id):
if str(segment.document_id) != str(document_id_str):
raise NotFound("Segment not found.")
child_chunk_id_str = str(child_chunk_id)
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -16,7 +16,7 @@ from libs.passport import PassportService
from libs.token import extract_webapp_passport
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if not webapp_settings:
raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
_validate_user_accessibility(
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
raise Unauthorized("Please re-login to access the web app.")
app_id = AppService.get_app_id_by_code(app_code)
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
!= WebAppAccessMode.PUBLIC
)
if app_web_auth_enabled:
raise WebAppAuthRequiredError()

View File

@ -22,9 +22,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
@ -150,44 +147,9 @@ class BaseAgentRunner(AppRunner):
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.entity.description.llm,
parameters={
"type": "object",
"properties": {},
"required": [],
},
parameters=tool_entity.get_llm_parameters_json_schema(),
)
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
@ -252,40 +214,7 @@ class BaseAgentRunner(AppRunner):
"""
update prompt message tool
"""
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
prompt_tool.parameters = tool.get_llm_parameters_json_schema()
return prompt_tool
def create_agent_thought(

View File

@ -198,7 +198,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@ -167,7 +167,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,

View File

@ -161,7 +161,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
),
query=query,
files=list(file_objs),
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
parent_message_id=(
args.get("parent_message_id")
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
else UUID_NIL
),
user_id=user.id,
invoke_from=invoke_from,
extras=extras,

View File

@ -53,6 +53,14 @@ from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
# Maps the entry surface a workflow was invoked from to the HITL surface that
# its resume tokens must be filtered for. Surfaces not in this map fall back to
# the general priority ordering (typically CONSOLE > BACKSTAGE).
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
}
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -340,11 +348,7 @@ class WorkflowResponseConverter:
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;

View File

@ -731,6 +731,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
match invoke_from:
case InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
case InvokeFrom.OPENAPI:
created_from = WorkflowAppLogCreatedFrom.OPENAPI
case InvokeFrom.EXPLORE:
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
case InvokeFrom.WEB_APP:

View File

@ -24,6 +24,7 @@ class UserFrom(StrEnum):
class InvokeFrom(StrEnum):
SERVICE_API = "service-api"
OPENAPI = "openapi"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
@ -42,6 +43,7 @@ class InvokeFrom(StrEnum):
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
InvokeFrom.OPENAPI: "openapi",
}
return source_mapping.get(self, "dev")

View File

@ -47,6 +47,12 @@ from graphon.graph_events import (
)
from graphon.node_events import NodeRunResult
from libs.datetime_utils import naive_utc_now
from services.workflow.inspector_events import (
publish_node_changed as _inspector_publish_node_changed,
)
from services.workflow.inspector_events import (
publish_workflow_completed as _inspector_publish_workflow_completed,
)
@dataclass(slots=True)
@ -163,6 +169,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
execution = self._get_workflow_execution()
@ -173,6 +180,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
execution = self._get_workflow_execution()
@ -184,6 +192,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._fail_running_node_executions(error_message=event.error)
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
execution = self._get_workflow_execution()
@ -194,6 +203,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._fail_running_node_executions(error_message=execution.error_message or "")
self._workflow_execution_repository.save(execution)
self._enqueue_trace_task(execution)
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
execution = self._get_workflow_execution()
@ -241,6 +251,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot
_inspector_publish_node_changed(workflow_run_id=execution.id_, node_id=event.node_id, status="running")
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -248,6 +259,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
domain_execution.error = event.error
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
_inspector_publish_node_changed(
workflow_run_id=self._get_workflow_execution().id_,
node_id=domain_execution.node_id,
status="retry",
)
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -257,6 +273,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event.finished_at,
)
_inspector_publish_node_changed(
workflow_run_id=self._get_workflow_execution().id_,
node_id=domain_execution.node_id,
status="succeeded",
)
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -267,6 +288,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
error=event.error,
finished_at=event.finished_at,
)
_inspector_publish_node_changed(
workflow_run_id=self._get_workflow_execution().id_,
node_id=domain_execution.node_id,
status="failed",
)
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@ -277,6 +303,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
error=event.error,
finished_at=event.finished_at,
)
_inspector_publish_node_changed(
workflow_run_id=self._get_workflow_execution().id_,
node_id=domain_execution.node_id,
status="exception",
)
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
domain_execution = self._get_node_execution(event.id)

View File

@ -42,7 +42,6 @@ from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetPro
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -282,8 +281,7 @@ class IndexingRunner:
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
if dify_config.BILLING_ENABLED:
count = len(extract_settings)
batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
if count > batch_upload_limit:

View File

@ -3,6 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator
from typing import Any, cast
from urllib.parse import unquote
import httpx
from pydantic import BaseModel
@ -53,6 +54,9 @@ else:
logger = logging.getLogger(__name__)
PLUGIN_DAEMON_MAX_PATH_LENGTH = 4096
PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH = 8
_httpx_client: httpx.Client = get_pooled_http_client(
"plugin_daemon",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
@ -103,6 +107,20 @@ class BasePluginClient:
params: dict[str, Any] | None,
files: dict[str, Any] | None,
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
if len(path) > PLUGIN_DAEMON_MAX_PATH_LENGTH:
raise ValueError(f"Invalid plugin daemon path: path length exceeds {PLUGIN_DAEMON_MAX_PATH_LENGTH}")
decoded_path = path
for _ in range(PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH):
next_decoded_path = unquote(decoded_path)
if next_decoded_path == decoded_path:
break
decoded_path = next_decoded_path
else:
raise ValueError("Invalid plugin daemon path: path is too deeply encoded")
if any(seg == ".." for seg in decoded_path.split("/")):
raise ValueError(f"Invalid plugin daemon path: traversal sequence detected in {path!r}")
url = plugin_daemon_inner_api_baseurl / path
prepared_headers = dict(headers or {})
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY

View File

@ -126,34 +126,89 @@ class Tool(ABC):
message_id: str | None = None,
) -> list[ToolParameter]:
"""
get merged runtime parameters
Get the effective parameter declarations for this tool.
Runtime parameters override declared parameters by name and append new
parameters, but the returned list is always detached from the tool's
cached declarations so callers can safely mutate it while building
downstream schemas.
:return: merged runtime parameters
"""
parameters = self.entity.parameters
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
parameters = [deepcopy(parameter) for parameter in self.entity.parameters or []]
user_parameters = [
deepcopy(parameter)
for parameter in self.get_runtime_parameters(
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
or []
]
parameter_indexes = {parameter.name: index for index, parameter in enumerate(parameters)}
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
break
else:
# add new parameter
existing_index = parameter_indexes.get(parameter.name)
if existing_index is None:
parameter_indexes[parameter.name] = len(parameters)
parameters.append(parameter)
continue
parameters[existing_index] = parameter
return parameters
def get_llm_parameters_json_schema(
self,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
) -> dict[str, Any]:
"""Build the model-visible JSON schema from effective tool parameters.
Hidden/manual parameters stay available for invocation preparation on the
API side, but are intentionally omitted from the LLM-facing schema.
"""
schema: dict[str, Any] = {
"type": "object",
"properties": {},
"required": [],
}
for parameter in self.get_merged_runtime_parameters(
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
):
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
parameter_schema: dict[str, Any] = (
{
"type": parameter.type.as_normal_type(),
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else deepcopy(parameter.input_schema)
)
parameter_schema.setdefault("description", parameter.llm_description or "")
if parameter.type == ToolParameter.ToolParameterType.SELECT and parameter.options:
parameter_schema["enum"] = [option.value for option in parameter.options]
schema["properties"][parameter.name] = parameter_schema
if parameter.required:
schema["required"].append(parameter.name)
return schema
def create_image_message(
self,
image: str,

View File

@ -63,7 +63,7 @@ def _get_surface_form_token(
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token

View File

@ -11,13 +11,15 @@ from models.human_input import RecipientType
class HumanInputSurface(StrEnum):
SERVICE_API = "service_api"
CONSOLE = "console"
OPENAPI = "openapi"
# Service API is intentionally narrower than other surfaces: app-token callers
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
# should only be able to act on end-user web forms, not internal console flows.
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
}
# A single HITL form can have multiple recipient records; this shared priority

View File

@ -472,6 +472,9 @@ class DifyNodeFactory(NodeFactory):
if issubclass(node_class, DifyAgentNode):
from clients.agent_backend import AgentBackendRunEventAdapter, AgentBackendRunRequestBuilder
from clients.agent_backend.factory import create_agent_backend_run_client
from core.workflow.nodes.agent_v2.file_tenant_validator import UploadFileTenantValidator
from core.workflow.nodes.agent_v2.output_failure_orchestrator import OutputFailureOrchestrator
from core.workflow.nodes.agent_v2.output_type_checker import PerOutputTypeChecker
return {
"binding_resolver": WorkflowAgentBindingResolver(),
@ -486,6 +489,11 @@ class DifyNodeFactory(NodeFactory):
),
"event_adapter": AgentBackendRunEventAdapter(),
"output_adapter": WorkflowAgentOutputAdapter(),
# Stage 4 §5/§7: per-output validation + failure orchestration. The
# tenant validator queries upload_files so it stays cheap when
# outputs contain no file refs.
"type_checker": PerOutputTypeChecker(file_validator=UploadFileTenantValidator()),
"failure_orchestrator": OutputFailureOrchestrator(),
}
return {
"strategy_resolver": self._agent_strategy_resolver,

View File

@ -7,9 +7,11 @@ from clients.agent_backend import (
AgentBackendError,
AgentBackendHTTPError,
AgentBackendInternalEventType,
AgentBackendRunCancelledInternalEvent,
AgentBackendRunClient,
AgentBackendRunEventAdapter,
AgentBackendRunFailedInternalEvent,
AgentBackendRunPausedInternalEvent,
AgentBackendRunSucceededInternalEvent,
AgentBackendStreamError,
AgentBackendStreamInternalEvent,
@ -21,10 +23,18 @@ from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
from models.agent_config_entities import WorkflowNodeJobConfig
from .binding_resolver import WorkflowAgentBindingError, WorkflowAgentBindingResolver
from .entities import DifyAgentNodeData
from .output_adapter import WorkflowAgentOutputAdapter
from .output_failure_orchestrator import (
FailedOutput,
OutputFailureDecision,
OutputFailureKind,
OutputFailureOrchestrator,
)
from .output_type_checker import OutputTypeCheckOutcome, PerOutputTypeChecker
from .runtime_request_builder import (
WorkflowAgentRuntimeBuildContext,
WorkflowAgentRuntimeRequestBuilder,
@ -36,6 +46,17 @@ if TYPE_CHECKING:
from graphon.runtime import GraphRuntimeState
# Stage 4 §5+§7: the terminal events that `_consume_event_stream` may return.
# Stream + started events are filtered out before we yield; transport errors
# are surfaced as a separate StreamCompletedEvent in the second tuple slot.
_TerminalAgentBackendEvent = (
AgentBackendRunSucceededInternalEvent
| AgentBackendRunFailedInternalEvent
| AgentBackendRunCancelledInternalEvent
| AgentBackendRunPausedInternalEvent
)
class DifyAgentNode(Node[DifyAgentNodeData]):
node_type = BuiltinNodeTypes.AGENT
@ -51,6 +72,8 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
agent_backend_client: AgentBackendRunClient,
event_adapter: AgentBackendRunEventAdapter,
output_adapter: WorkflowAgentOutputAdapter,
type_checker: PerOutputTypeChecker,
failure_orchestrator: OutputFailureOrchestrator,
) -> None:
super().__init__(
node_id=node_id,
@ -63,6 +86,8 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
self._agent_backend_client = agent_backend_client
self._event_adapter = event_adapter
self._output_adapter = output_adapter
self._type_checker = type_checker
self._failure_orchestrator = failure_orchestrator
@classmethod
def version(cls) -> str:
@ -86,6 +111,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
}
}
# ──── Setup: resolve binding once + extract declared outputs for stage 4 checks ────
try:
bundle = self._binding_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
@ -93,32 +119,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
workflow_id=workflow_id,
node_id=self._node_id,
)
runtime_request = self._runtime_request_builder.build(
WorkflowAgentRuntimeBuildContext(
dify_context=dify_ctx,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=self._node_id,
node_execution_id=self.id,
variable_pool=self.graph_runtime_state.variable_pool,
binding=bundle.binding,
agent=bundle.agent,
snapshot=bundle.snapshot,
)
)
inputs = {"agent_backend_request": runtime_request.redacted_request}
metadata = dict(runtime_request.metadata)
process_data = {
"agent_id": bundle.agent.id,
"agent_config_snapshot_id": bundle.snapshot.id,
"binding_id": bundle.binding.id,
}
create_response = self._agent_backend_client.create_run(runtime_request.request)
metadata["agent_backend"] = {
**dict(metadata.get("agent_backend") or {}),
"run_id": create_response.run_id,
"status": create_response.status,
}
except WorkflowAgentBindingError as error:
yield self._failure_event(
inputs=inputs,
@ -128,37 +128,195 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
error_type=error.error_code,
)
return
except WorkflowAgentRuntimeRequestBuildError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=error.error_code,
process_data = {
"agent_id": bundle.agent.id,
"agent_config_snapshot_id": bundle.snapshot.id,
"binding_id": bundle.binding.id,
}
# Stage 4 §4.1 (D-3): use effective outputs so defaults flow through both
# the backend request and the post-run type check.
node_job = WorkflowNodeJobConfig.model_validate(bundle.binding.node_job_config_dict)
effective_outputs = list(
WorkflowAgentRuntimeRequestBuilder.effective_declared_outputs(list(node_job.declared_outputs))
)
outputs_by_name = {o.name: o for o in effective_outputs}
# ──── Retry loop (Stage 4 §7) ────
attempt = 0
while True:
try:
runtime_request = self._runtime_request_builder.build(
WorkflowAgentRuntimeBuildContext(
dify_context=dify_ctx,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=self._node_id,
node_execution_id=self.id,
variable_pool=self.graph_runtime_state.variable_pool,
binding=bundle.binding,
agent=bundle.agent,
snapshot=bundle.snapshot,
attempt=attempt,
)
)
except WorkflowAgentRuntimeRequestBuildError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=error.error_code,
)
return
except Exception as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type="agent_workflow_node_runtime_error",
)
return
# Capture inputs only from the first attempt so retry doesn't churn the
# node's "inputs" payload that ends up in the workflow detail view.
if attempt == 0:
inputs = {"agent_backend_request": runtime_request.redacted_request}
metadata = dict(runtime_request.metadata)
metadata["attempt"] = attempt
try:
create_response = self._agent_backend_client.create_run(runtime_request.request)
except AgentBackendError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=self._agent_backend_error_type(error),
)
return
metadata["agent_backend"] = {
**dict(metadata.get("agent_backend") or {}),
"run_id": create_response.run_id,
"status": create_response.status,
}
terminal_event, exhausted = self._consume_event_stream(create_response.run_id, metadata)
if exhausted is not None:
# Streaming error / unexpected end — surface immediately without
# retrying because the failure is transport-level.
yield exhausted
return
if terminal_event is None:
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_stream_exhausted_result(
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
# Non-success terminal (failed / cancelled / paused) skips per-output
# post-processing — the backend itself already failed.
if not isinstance(terminal_event, AgentBackendRunSucceededInternalEvent):
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_failure_result(
event=terminal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
# ──── Stage 4: per-output type check ────
type_check = self._type_checker.check(
declared_outputs=effective_outputs,
raw_output=terminal_event.output,
tenant_id=dify_ctx.tenant_id,
)
return
except AgentBackendError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=self._agent_backend_error_type(error),
self._record_type_check_metadata(metadata, type_check)
if not type_check.has_failures:
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_success_result(
event=terminal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
# ──── Stage 4: orchestrate retry / default / fail ────
failures = [
FailedOutput(
declared=outputs_by_name[result.name],
failure_kind=OutputFailureKind.TYPE_CHECK,
reason=result.reason,
)
for result in type_check.failures
if result.name in outputs_by_name
]
outcome = self._failure_orchestrator.decide(failures=failures, current_attempt=attempt)
metadata["output_failure_decision"] = outcome.decision.value
metadata["output_failure_reason"] = outcome.primary_reason
if outcome.decision == OutputFailureDecision.RETRY:
attempt = outcome.next_attempt
continue
if outcome.decision == OutputFailureDecision.USE_DEFAULT:
patched_event = self._patch_event_with_defaults(terminal_event, outcome.per_output_actions)
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_success_result(
event=patched_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
error_type = (
"output_type_check_failed_fail_branch"
if outcome.decision == OutputFailureDecision.TAKE_FAIL_BRANCH
else "output_type_check_failed"
)
return
except Exception as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type="agent_workflow_node_runtime_error",
error=outcome.primary_reason,
error_type=error_type,
)
return
def _consume_event_stream(
self,
run_id: str,
metadata: dict[str, Any],
) -> tuple[
_TerminalAgentBackendEvent | None,
StreamCompletedEvent | None,
]:
"""Consume the SSE stream for one Agent backend run.
Returns a 2-tuple ``(terminal_event, transport_failure)``:
- ``terminal_event``: the first non-stream/non-started internal event,
or ``None`` if the stream ended without one.
- ``transport_failure``: a populated ``StreamCompletedEvent`` when the
stream itself errored (backend/HTTP/protocol fault). Mutually
exclusive with ``terminal_event``.
"""
stream_event_count = 0
try:
for public_event in self._agent_backend_client.stream_events(create_response.run_id):
for public_event in self._agent_backend_client.stream_events(run_id):
stream_event_count += 1
for internal_event in self._event_adapter.adapt(public_event):
if internal_event.type == AgentBackendInternalEventType.RUN_STARTED:
@ -171,58 +329,78 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
**dict(metadata.get("agent_backend") or {}),
"stream_event_count": stream_event_count,
}
if isinstance(internal_event, AgentBackendRunSucceededInternalEvent):
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_success_result(
event=internal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
# Narrow to the 4 known terminal event types so the caller
# can hand the result to ``build_failure_result`` (which is
# typed against the union). Anything else is a protocol-
# level surprise we surface as a stream error.
if isinstance(
internal_event,
AgentBackendRunFailedInternalEvent,
) or internal_event.type in {
AgentBackendInternalEventType.RUN_CANCELLED,
AgentBackendInternalEventType.RUN_PAUSED,
}:
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_failure_result(
event=internal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
AgentBackendRunSucceededInternalEvent
| AgentBackendRunFailedInternalEvent
| AgentBackendRunCancelledInternalEvent
| AgentBackendRunPausedInternalEvent,
):
return internal_event, None
return None, self._failure_event(
inputs={},
process_data={},
metadata=metadata,
error=f"Unexpected internal event type {internal_event.type!r}",
error_type="agent_backend_stream_error",
)
except AgentBackendError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
return None, self._failure_event(
inputs={},
process_data={},
metadata=metadata,
error=str(error),
error_type=self._agent_backend_error_type(error),
)
return
except Exception as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
return None, self._failure_event(
inputs={},
process_data={},
metadata=metadata,
error=str(error),
error_type="agent_backend_stream_error",
)
return
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_stream_exhausted_result(
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return None, None
@staticmethod
def _record_type_check_metadata(metadata: dict[str, Any], outcome: OutputTypeCheckOutcome) -> None:
# Surface enough detail in metadata for Inspector / debug logs without
# leaking the raw failing values (which may be sensitive).
metadata["output_type_check"] = {
"passed": not outcome.has_failures,
"results": [
{
"name": r.name,
"type": r.declared_type.value,
"status": r.status.value,
"reason": r.reason,
}
for r in outcome.results
],
}
@staticmethod
def _patch_event_with_defaults(
event: AgentBackendRunSucceededInternalEvent,
per_output_actions: Mapping[str, Any],
) -> AgentBackendRunSucceededInternalEvent:
"""Merge USE_DEFAULT replacements into the success event's output dict.
The event is a frozen dataclass / Pydantic model; we copy with the
replacements applied so downstream code (output_adapter normalize) sees
the patched payload.
"""
if not per_output_actions:
return event
original = event.output if isinstance(event.output, Mapping) else {}
patched_output: dict[str, Any] = dict(original)
patched_output.update(per_output_actions)
return event.model_copy(update={"output": patched_output})
@staticmethod
def _failure_event(

View File

@ -0,0 +1,46 @@
"""Tenant-scope validator for file refs produced by Agent backend outputs.
Stage 4 §5.3: every file output the Agent backend produces must resolve to an
``upload_files`` row that belongs to the current tenant; cross-tenant file
references must never be plumbed downstream. ``PerOutputTypeChecker`` accepts a
``FileTenantValidator`` Protocol so unit tests can stub the check without
hitting Postgres.
This module supplies the production implementation that queries the
``upload_files`` table via SQLAlchemy.
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.exc import DataError, SQLAlchemyError
from core.db.session_factory import session_factory
from models.model import UploadFile
class UploadFileTenantValidator:
"""Production ``FileTenantValidator`` backed by the ``upload_files`` table.
Returns ``False`` (rejects the file) on any pathological input: empty
file_id/tenant_id, non-UUID file_id format, DB errors. The Agent backend
may produce arbitrary strings inside file refs since the schema only
asserts ``{type: string}``; treating malformed refs as invalid keeps the
workflow node from crashing on garbage backend output.
"""
def is_owned_by_tenant(self, *, file_id: str, tenant_id: str) -> bool:
if not file_id or not tenant_id:
return False
try:
UUID(file_id)
except (ValueError, TypeError):
return False
try:
with session_factory.create_session() as session:
owner_tenant_id = session.scalar(select(UploadFile.tenant_id).where(UploadFile.id == file_id))
except (DataError, SQLAlchemyError):
return False
return owner_tenant_id == tenant_id

View File

@ -0,0 +1,201 @@
"""Per-output failure decision logic for Workflow Agent Node v2.
Stage 4 §7. Pure orchestration: given a set of per-output failures plus their
configured ``DeclaredOutputFailureStrategy``, decide whether the workflow node
should retry the Agent backend run, take a fail branch, fall back to default
values, or fail outright.
The orchestrator is intentionally state-free. The caller (``agent_node._run``)
owns the retry attempt counter and is responsible for actually issuing the
re-run; this module only computes the decision.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from enum import StrEnum
from typing import Any
from models.agent_config_entities import (
DeclaredOutputConfig,
DeclaredOutputFailureStrategy,
OutputErrorStrategy,
)
class OutputFailureKind(StrEnum):
"""Why the per-output post-processing failed."""
TYPE_CHECK = "type_check"
OUTPUT_CHECK = "output_check"
class OutputFailureDecision(StrEnum):
"""What the runtime should do after collecting per-output failures."""
# Re-invoke Agent backend (entire node re-runs). Used while retry budget
# remains for at least one failed output.
RETRY = "retry"
# Replace each failed output's value with its declared default_value and
# surface the run as successful.
USE_DEFAULT = "use_default"
# Mark the workflow node as failed; halt downstream propagation.
FAIL_NODE = "fail_node"
# Surface the node as exception → route through fail branch outbound edge.
TAKE_FAIL_BRANCH = "take_fail_branch"
@dataclass(frozen=True, slots=True)
class FailedOutput:
"""One failed output that the orchestrator needs to reason about."""
declared: DeclaredOutputConfig
failure_kind: OutputFailureKind
reason: str | None = None
@dataclass(frozen=True, slots=True)
class OutputFailureOutcome:
"""Outcome of orchestrating one batch of failures.
``per_output_actions`` is non-empty only when ``decision == USE_DEFAULT``;
it maps the failed output's ``name`` to the value that should be merged
into the node's outputs in place of the failed value.
"""
decision: OutputFailureDecision
per_output_actions: Mapping[str, Any]
next_attempt: int
primary_reason: str
failure_kinds: tuple[OutputFailureKind, ...]
# Stage 4 §7 — precedence used to merge differing per-output strategies into a
# single node-level decision when multiple outputs fail at once.
# Smaller integer = lower priority. FAIL_BRANCH wins overall.
_STRATEGY_TERMINAL_RANK: dict[OutputErrorStrategy, int] = {
OutputErrorStrategy.DEFAULT_VALUE: 0,
OutputErrorStrategy.STOP: 1,
OutputErrorStrategy.FAIL_BRANCH: 2,
}
class OutputFailureOrchestrator:
"""Pure decision engine for per-output failure handling."""
def decide(
self,
*,
failures: list[FailedOutput],
current_attempt: int,
) -> OutputFailureOutcome:
"""Compute the next action given a non-empty list of failures.
``current_attempt`` is zero-indexed: ``0`` means the failures come
from the first backend run, ``1`` from the first retry, etc. The
returned ``next_attempt`` is the value the caller should use for the
next iteration when ``decision == RETRY``.
"""
if not failures:
raise ValueError("OutputFailureOrchestrator.decide() requires at least one failure")
# Stage 4 §7: any output whose retry budget is not yet exhausted forces
# a whole-node retry. The effective max-retries is the *maximum* across
# all currently-failed outputs so retry continues until every output's
# budget is spent (or it goes ready).
retry_budget = max(
(f.declared.failure_strategy.retry.max_retries if f.declared.failure_strategy.retry.enabled else 0)
for f in failures
)
if current_attempt < retry_budget:
return OutputFailureOutcome(
decision=OutputFailureDecision.RETRY,
per_output_actions={},
next_attempt=current_attempt + 1,
primary_reason=self._summarize(failures),
failure_kinds=self._failure_kinds(failures),
)
# Retry budget exhausted: collapse each per-output terminal action into
# a single node-level decision via the precedence table.
merged = self._merge_terminal_decisions(failures)
per_output_actions: dict[str, Any] = {}
if merged == OutputFailureDecision.USE_DEFAULT:
for failure in failures:
strategy = failure.declared.failure_strategy
if strategy.on_failure == OutputErrorStrategy.DEFAULT_VALUE:
per_output_actions[failure.declared.name] = strategy.default_value
return OutputFailureOutcome(
decision=merged,
per_output_actions=per_output_actions,
next_attempt=current_attempt,
primary_reason=self._summarize(failures),
failure_kinds=self._failure_kinds(failures),
)
@staticmethod
def _merge_terminal_decisions(failures: list[FailedOutput]) -> OutputFailureDecision:
# Pick the highest-precedence strategy across all failures.
winning: OutputErrorStrategy = OutputErrorStrategy.DEFAULT_VALUE
winning_rank = -1
for failure in failures:
strategy = failure.declared.failure_strategy.on_failure
rank = _STRATEGY_TERMINAL_RANK[strategy]
if rank > winning_rank:
winning = strategy
winning_rank = rank
return _TERMINAL_STRATEGY_TO_DECISION[winning]
@staticmethod
def _summarize(failures: list[FailedOutput]) -> str:
parts: list[str] = []
for failure in failures:
reason = failure.reason or "no reason recorded"
parts.append(f"{failure.declared.name}[{failure.failure_kind.value}]: {reason}")
return "; ".join(parts)
@staticmethod
def _failure_kinds(failures: list[FailedOutput]) -> tuple[OutputFailureKind, ...]:
seen: list[OutputFailureKind] = []
for failure in failures:
if failure.failure_kind not in seen:
seen.append(failure.failure_kind)
return tuple(seen)
_TERMINAL_STRATEGY_TO_DECISION: dict[OutputErrorStrategy, OutputFailureDecision] = {
OutputErrorStrategy.STOP: OutputFailureDecision.FAIL_NODE,
OutputErrorStrategy.DEFAULT_VALUE: OutputFailureDecision.USE_DEFAULT,
OutputErrorStrategy.FAIL_BRANCH: OutputFailureDecision.TAKE_FAIL_BRANCH,
}
def retry_idempotency_key(
*,
workflow_run_id: str | None,
node_execution_id: str,
attempt: int,
) -> str:
"""Compute the Agent backend ``idempotency_key`` for a given attempt.
Stage 4 §7 / decision D-4: each retry must use a distinct key so the
backend's protocol-level dedup doesn't return the previous run's id.
First attempt (attempt=0) matches the pre-stage-4 key shape so logs stay
backward compatible.
"""
base = f"{workflow_run_id}:{node_execution_id}" if workflow_run_id else node_execution_id
if attempt <= 0:
return base
return f"{base}:retry-{attempt}"
def build_failure_strategy_for(declared: DeclaredOutputConfig) -> DeclaredOutputFailureStrategy:
"""Convenience accessor that always returns a populated strategy.
Existing callers that read ``output.failure_strategy`` already get a
populated default thanks to the BaseModel default_factory, but this helper
documents the contract and gives the orchestrator's tests a single hook.
"""
return declared.failure_strategy

View File

@ -0,0 +1,244 @@
"""Per-output runtime type checker for Workflow Agent Node v2.
Stage 4 §5: after Agent backend returns ``run_succeeded.data.output`` (a JSON
object that already passed the ``dify.output`` layer's JSON Schema validation
inside pydantic-ai), the API side runs a *second* pass that:
1. Locates each declared output by name in the backend payload.
2. Asserts the value's shape against the declared ``DeclaredOutputType``
(including array items and file ref objects).
3. For file outputs, verifies the referenced ``file_id`` resolves to a file
owned by the current tenant (PRD §5.3 file output reference safety).
The checker is intentionally pure: it takes data in and returns a structured
outcome out. ``FileTenantValidator`` is injected as a Protocol so unit tests
can stub tenant resolution without DB access.
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, Protocol
from models.agent_config_entities import (
DeclaredArrayItem,
DeclaredOutputConfig,
DeclaredOutputType,
)
class OutputTypeCheckStatus(StrEnum):
"""Lifecycle status of a single declared output after type check."""
READY = "ready"
NOT_PRODUCED = "not_produced"
TYPE_CHECK_FAILED = "type_check_failed"
@dataclass(frozen=True, slots=True)
class OutputTypeCheckResult:
"""Outcome of type-checking one declared output.
``value`` carries the raw payload value as it appeared in the backend
response. For ``TYPE_CHECK_FAILED`` results the value is preserved so the
Failure Orchestrator can decide whether to surface it (e.g. for debug
metadata) — it is **not** safe to feed into downstream nodes.
"""
name: str
declared_type: DeclaredOutputType
status: OutputTypeCheckStatus
value: Any
reason: str | None = None
@dataclass(frozen=True, slots=True)
class OutputTypeCheckOutcome:
"""Aggregate per-output type-check results for one Agent backend run."""
results: tuple[OutputTypeCheckResult, ...]
@property
def failures(self) -> tuple[OutputTypeCheckResult, ...]:
return tuple(r for r in self.results if r.status == OutputTypeCheckStatus.TYPE_CHECK_FAILED)
@property
def has_failures(self) -> bool:
return bool(self.failures)
def by_name(self) -> dict[str, OutputTypeCheckResult]:
return {r.name: r for r in self.results}
class FileTenantValidator(Protocol):
"""Verify a file ref resolves to a file owned by the given tenant."""
def is_owned_by_tenant(self, *, file_id: str, tenant_id: str) -> bool: ...
# Recognized aliases the Agent backend (or pydantic-ai) may produce for the
# canonical file id field. The canonical spec form is ``file_id`` (§5.2).
_FILE_ID_KEYS: tuple[str, ...] = ("file_id", "upload_file_id", "tool_file_id")
class PerOutputTypeChecker:
"""Validate that each declared output is present and shaped correctly.
The checker handles array items recursively and is opinionated about file
refs: only dicts with at least one recognized id key plus a tenant-scope
match pass. Stage 4 §5.2 + §5.3.
"""
def __init__(self, file_validator: FileTenantValidator) -> None:
self._file_validator = file_validator
def check(
self,
*,
declared_outputs: list[DeclaredOutputConfig],
raw_output: Mapping[str, Any] | Any,
tenant_id: str,
) -> OutputTypeCheckOutcome:
"""Run type check for every declared output.
``raw_output`` should be ``run_succeeded.data.output``. The backend
always returns a dict because the ``dify.output`` layer wraps every
schema in a top-level object; if it isn't a dict (e.g. backend
misbehaving) every required output is flagged as ``TYPE_CHECK_FAILED``.
"""
results: list[OutputTypeCheckResult] = []
payload = raw_output if isinstance(raw_output, Mapping) else None
for declared in declared_outputs:
if payload is None:
results.append(
OutputTypeCheckResult(
name=declared.name,
declared_type=declared.type,
status=OutputTypeCheckStatus.TYPE_CHECK_FAILED,
value=raw_output,
reason="Backend output is not a JSON object.",
)
)
continue
if declared.name not in payload:
if declared.required:
results.append(
OutputTypeCheckResult(
name=declared.name,
declared_type=declared.type,
status=OutputTypeCheckStatus.TYPE_CHECK_FAILED,
value=None,
reason=f"Required output {declared.name!r} is missing from backend payload.",
)
)
else:
results.append(
OutputTypeCheckResult(
name=declared.name,
declared_type=declared.type,
status=OutputTypeCheckStatus.NOT_PRODUCED,
value=None,
)
)
continue
value = payload[declared.name]
failure_reason = self._validate_value(
declared_type=declared.type,
value=value,
tenant_id=tenant_id,
array_item=declared.array_item,
)
if failure_reason is None:
results.append(
OutputTypeCheckResult(
name=declared.name,
declared_type=declared.type,
status=OutputTypeCheckStatus.READY,
value=value,
)
)
else:
results.append(
OutputTypeCheckResult(
name=declared.name,
declared_type=declared.type,
status=OutputTypeCheckStatus.TYPE_CHECK_FAILED,
value=value,
reason=failure_reason,
)
)
return OutputTypeCheckOutcome(results=tuple(results))
def _validate_value(
self,
*,
declared_type: DeclaredOutputType,
value: Any,
tenant_id: str,
array_item: DeclaredArrayItem | None,
) -> str | None:
"""Return ``None`` on success, or a human-readable failure reason."""
if declared_type == DeclaredOutputType.STRING:
if not isinstance(value, str):
return f"expected string, got {type(value).__name__}"
return None
if declared_type == DeclaredOutputType.NUMBER:
# ``bool`` is a subclass of int in Python; PRD treats numbers as
# strictly numeric so we reject bools here.
if not isinstance(value, (int, float)) or isinstance(value, bool):
return f"expected number, got {type(value).__name__}"
return None
if declared_type == DeclaredOutputType.BOOLEAN:
if not isinstance(value, bool):
return f"expected boolean, got {type(value).__name__}"
return None
if declared_type == DeclaredOutputType.OBJECT:
if not isinstance(value, Mapping):
return f"expected object, got {type(value).__name__}"
return None
if declared_type == DeclaredOutputType.ARRAY:
if not isinstance(value, list):
return f"expected array, got {type(value).__name__}"
if array_item is None:
# Defensive: the model validator should have populated this; if
# absent, accept any items rather than crash.
return None
for index, item in enumerate(value):
item_reason = self._validate_value(
declared_type=array_item.type,
value=item,
tenant_id=tenant_id,
array_item=None,
)
if item_reason is not None:
return f"items[{index}]: {item_reason}"
return None
if declared_type == DeclaredOutputType.FILE:
return self._validate_file_value(value=value, tenant_id=tenant_id)
# Defensive: future DeclaredOutputType members reach this branch and
# should fail loudly so we never silently accept unknown shapes.
return f"unsupported declared_type={declared_type!r}"
def _validate_file_value(self, *, value: Any, tenant_id: str) -> str | None:
if not isinstance(value, Mapping):
return f"expected file ref object, got {type(value).__name__}"
file_id = self._extract_file_id(value)
if file_id is None:
return "file ref missing a recognized file_id field"
if not self._file_validator.is_owned_by_tenant(file_id=file_id, tenant_id=tenant_id):
return f"file_id {file_id!r} is not accessible to tenant {tenant_id!r}"
return None
@staticmethod
def _extract_file_id(value: Mapping[str, Any]) -> str | None:
for key in _FILE_ID_KEYS:
candidate = value.get(key)
if isinstance(candidate, str) and candidate:
return candidate
return None

View File

@ -0,0 +1,268 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol, cast
from dify_agent.layers.dify_plugin import (
DifyPluginCredentialValue,
DifyPluginToolConfig,
DifyPluginToolCredentialType,
DifyPluginToolParameter,
DifyPluginToolParameterForm,
DifyPluginToolsLayerConfig,
)
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.errors import (
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.tool_manager import ToolManager
from models.agent_config_entities import AgentSoulDifyToolConfig, AgentSoulToolsConfig
from models.provider_ids import ToolProviderID
class WorkflowAgentPluginToolsBuildError(ValueError):
"""Raised when Agent Soul tools cannot be prepared for Agent backend."""
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
class AgentToolRuntimeProvider(Protocol):
def get_agent_tool_runtime(
self,
tenant_id: str,
app_id: str,
agent_tool: AgentToolEntity,
user_id: str | None = None,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
variable_pool: Any | None = None,
) -> Tool: ...
class WorkflowAgentPluginToolsBuilder:
"""Prepare Agent Soul Dify Plugin Tools for the public Agent backend DTO."""
def __init__(self, *, tool_runtime_provider: AgentToolRuntimeProvider | None = None) -> None:
self._tool_runtime_provider = tool_runtime_provider or ToolManager
def build(
self,
*,
tenant_id: str,
app_id: str,
user_id: str | None,
tools: AgentSoulToolsConfig,
invoke_from: InvokeFrom,
) -> DifyPluginToolsLayerConfig | None:
"""Resolve user-selected Dify Plugin Tools into the Agent backend DTO.
``invoke_from`` is the *real* runtime caller category (DEBUGGER for a
Composer test run, SERVICE_API / WEB_APP for a published run). It must
be threaded through to :class:`ToolManager` so credential quotas, rate
limits, and audit tags match the actual call site.
"""
enabled_tools = [tool for tool in tools.dify_tools if tool.enabled]
if not enabled_tools:
return None
prepared: list[DifyPluginToolConfig] = []
seen_names: set[str] = set()
for tool_config in enabled_tools:
agent_tool = self._to_agent_tool_entity(tool_config)
tool_runtime = self._fetch_tool_runtime(
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
agent_tool=agent_tool,
invoke_from=invoke_from,
tool_config=tool_config,
)
exposed_name = self._exposed_tool_name(tool_config)
if exposed_name in seen_names:
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_name_duplicated",
f"Duplicate Dify Plugin Tool name {exposed_name!r}.",
)
seen_names.add(exposed_name)
prepared.append(self._to_backend_tool_config(tool_config, tool_runtime, exposed_name))
return DifyPluginToolsLayerConfig(tools=prepared)
def _fetch_tool_runtime(
self,
*,
tenant_id: str,
app_id: str,
user_id: str | None,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom,
tool_config: AgentSoulDifyToolConfig,
) -> Tool:
"""Resolve the API-side ``Tool`` runtime, mapping fetch errors to
Inspector-friendly error codes so callers can render distinct UX for
"tool definition gone" vs "credential failed".
"""
try:
return self._tool_runtime_provider.get_agent_tool_runtime(
tenant_id=tenant_id,
app_id=app_id,
agent_tool=agent_tool,
user_id=user_id,
invoke_from=invoke_from,
variable_pool=None,
)
except ToolProviderNotFoundError as exc:
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_declaration_not_found",
f"Dify Plugin Tool {tool_config.tool_name!r} declaration not found: {exc}",
) from exc
except ToolProviderCredentialValidationError as exc:
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_credential_invalid",
f"Dify Plugin Tool {tool_config.tool_name!r} credential validation failed: {exc}",
) from exc
except ValueError as exc:
# ToolManager raises bare ValueError when the agent tool's
# ``runtime`` / runtime parameters are missing. Surface it under a
# narrower error code than a generic "declaration not found" so
# frontend can render an actionable hint.
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_config_invalid",
f"Dify Plugin Tool {tool_config.tool_name!r} runtime construction failed: {exc}",
) from exc
@staticmethod
def _to_agent_tool_entity(tool_config: AgentSoulDifyToolConfig) -> AgentToolEntity:
return AgentToolEntity(
provider_type=ToolProviderType.value_of(tool_config.provider_type),
provider_id=WorkflowAgentPluginToolsBuilder._provider_id(tool_config),
tool_name=tool_config.tool_name,
tool_parameters=dict(tool_config.runtime_parameters),
credential_id=tool_config.credential_ref.id if tool_config.credential_ref else None,
)
@staticmethod
def _provider_id(tool_config: AgentSoulDifyToolConfig) -> str:
if tool_config.provider_id:
return tool_config.provider_id
assert tool_config.plugin_id is not None
assert tool_config.provider is not None
return f"{tool_config.plugin_id}/{tool_config.provider}"
@staticmethod
def _exposed_tool_name(tool_config: AgentSoulDifyToolConfig) -> str:
# Stage 3.1 decision: no user rename yet. Keep the model-visible tool
# name aligned with the plugin declaration identity.
return tool_config.tool_name
def _to_backend_tool_config(
self,
tool_config: AgentSoulDifyToolConfig,
tool_runtime: Tool,
exposed_name: str,
) -> DifyPluginToolConfig:
runtime = tool_runtime.runtime
if runtime is None:
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_config_invalid",
f"Dify Plugin Tool {tool_config.tool_name!r} has no runtime.",
)
provider_id = self._provider_id(tool_config)
plugin_id, provider = self._plugin_provider(tool_config, provider_id)
parameters = [
DifyPluginToolParameter.model_validate(parameter.model_dump(mode="json"))
for parameter in tool_runtime.get_merged_runtime_parameters()
]
runtime_parameters = self._runtime_parameters(tool_runtime, parameters)
description = tool_config.description
if description is None and tool_runtime.entity.description is not None:
description = tool_runtime.entity.description.llm
return DifyPluginToolConfig(
plugin_id=plugin_id,
provider=provider,
tool_name=tool_config.tool_name,
credential_type=self._credential_type(tool_config, runtime.credentials),
name=exposed_name,
description=description,
credentials=self._normalize_credentials(runtime.credentials, tool_name=tool_config.tool_name),
runtime_parameters=runtime_parameters,
parameters=parameters,
parameters_json_schema=cast(dict[str, Any], tool_runtime.get_llm_parameters_json_schema()),
)
@staticmethod
def _plugin_provider(tool_config: AgentSoulDifyToolConfig, provider_id: str) -> tuple[str, str]:
if tool_config.plugin_id and tool_config.provider:
return tool_config.plugin_id, tool_config.provider
provider_id_entity = ToolProviderID(provider_id)
return provider_id_entity.plugin_id, provider_id_entity.provider_name
@staticmethod
def _credential_type(
tool_config: AgentSoulDifyToolConfig,
credentials: Mapping[str, Any],
) -> DifyPluginToolCredentialType:
if not credentials and tool_config.credential_type == "unauthorized":
return "unauthorized"
return tool_config.credential_type
@staticmethod
def _runtime_parameters(
tool_runtime: Tool,
parameters: list[DifyPluginToolParameter],
) -> dict[str, Any]:
runtime = tool_runtime.runtime
runtime_parameters = dict(runtime.runtime_parameters if runtime is not None else {})
missing = [
parameter.name
for parameter in parameters
if parameter.form is not DifyPluginToolParameterForm.LLM
and parameter.required
and parameter.default is None
and parameter.name not in runtime_parameters
]
if missing:
names = ", ".join(sorted(missing))
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_runtime_parameter_missing",
f"Dify Plugin Tool {tool_runtime.entity.identity.name!r} is missing runtime parameters: {names}.",
)
return runtime_parameters
@staticmethod
def _normalize_credentials(
credentials: Mapping[str, Any],
*,
tool_name: str,
) -> dict[str, DifyPluginCredentialValue]:
"""Forward only scalar credential values to the Agent backend.
``DifyPluginCredentialValue`` is ``str | int | float | bool | None``.
Refusing non-scalar values (lists, dicts, custom objects) is safer than
``str(value)`` — stringifying a nested OAuth token blob produces a
Python ``repr`` that the plugin daemon cannot use, and we'd rather
surface a clear ``agent_tool_credential_shape_invalid`` than send junk.
"""
normalized: dict[str, DifyPluginCredentialValue] = {}
for key, value in credentials.items():
if isinstance(value, str | int | float | bool) or value is None:
normalized[key] = value
continue
raise WorkflowAgentPluginToolsBuildError(
"agent_tool_credential_shape_invalid",
(
f"Dify Plugin Tool {tool_name!r} credential {key!r} has a non-scalar value "
f"({type(value).__name__}); only str/int/float/bool/None are forwarded to the daemon."
),
)
return normalized

View File

@ -11,13 +11,14 @@ SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
"workflow_context",
"model",
"structured_output",
"tools.dify_tools",
}
)
RESERVED_AGENT_BACKEND_FEATURES = frozenset(
{
"skills_files",
"tools",
"tools.cli_tools",
"knowledge",
"human",
"env",
@ -32,7 +33,7 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
warnings: list[dict[str, str]] = []
soul_dump = agent_soul.model_dump(mode="json")
for section in sorted(RESERVED_AGENT_BACKEND_FEATURES):
value = soul_dump.get(section)
value = _get_nested(soul_dump, section)
has_value = bool(value)
if isinstance(value, dict):
has_value = any(bool(item) for item in value.values())
@ -41,11 +42,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
{
"section": f"agent_soul.{section}",
"code": "agent_backend_layer_not_available",
"message": f"{section} is saved in Agent Soul but is not executed by Agent backend in phase 3.",
"message": f"{section} is saved in Agent Soul but is not executed by Agent backend.",
}
)
reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed")
reserved_status["tools.dify_tools"] = "supported_when_config_valid"
return {
"supported": sorted(SUPPORTED_AGENT_BACKEND_FEATURES),
@ -53,3 +55,12 @@ def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any
"reserved_status": reserved_status,
"unsupported_runtime_warnings": warnings,
}
def _get_nested(value: dict[str, Any], path: str) -> Any:
current: Any = value
for part in path.split("."):
if not isinstance(current, dict):
return None
current = current.get(part)
return current

View File

@ -4,7 +4,8 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Protocol, cast
from dify_agent.protocol import CreateRunRequest, ExecutionContext
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
from dify_agent.protocol import CreateRunRequest
from clients.agent_backend import (
AgentBackendModelConfig,
@ -19,11 +20,17 @@ from graphon.variables.segments import Segment
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
from models.agent_config_entities import (
AgentSoulConfig,
DeclaredArrayItem,
DeclaredOutputConfig,
DeclaredOutputType,
WorkflowNodeJobConfig,
)
from models.agent_config_entities import (
effective_declared_outputs as _effective_declared_outputs,
)
from .output_failure_orchestrator import retry_idempotency_key
from .plugin_tools_builder import WorkflowAgentPluginToolsBuilder, WorkflowAgentPluginToolsBuildError
from .runtime_feature_manifest import build_runtime_feature_manifest
@ -56,6 +63,9 @@ class WorkflowAgentRuntimeBuildContext:
binding: WorkflowAgentNodeBinding
agent: Agent
snapshot: AgentConfigSnapshot
# Stage 4 §7 / D-4: 0 for the first run, then incremented per retry. Drives the
# idempotency key so the backend treats each retry as a fresh request.
attempt: int = 0
@dataclass(frozen=True, slots=True)
@ -75,9 +85,11 @@ class WorkflowAgentRuntimeRequestBuilder:
*,
credentials_provider: CredentialsProvider,
request_builder: AgentBackendRunRequestBuilder | None = None,
plugin_tools_builder: WorkflowAgentPluginToolsBuilder | None = None,
) -> None:
self._credentials_provider = credentials_provider
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
self._plugin_tools_builder = plugin_tools_builder or WorkflowAgentPluginToolsBuilder()
def build(self, context: WorkflowAgentRuntimeBuildContext) -> WorkflowAgentRuntimeRequest:
agent_soul = AgentSoulConfig.model_validate(context.snapshot.config_snapshot_dict)
@ -93,20 +105,44 @@ class WorkflowAgentRuntimeRequestBuilder:
workflow_job_prompt = node_job.workflow_prompt.strip() or "Run this workflow Agent Node for the current run."
user_prompt = workflow_context_prompt.strip() or "Use the current workflow context."
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
try:
tools_layer = self._plugin_tools_builder.build(
tenant_id=context.dify_context.tenant_id,
app_id=context.dify_context.app_id,
user_id=context.dify_context.user_id,
tools=agent_soul.tools,
# Thread the *real* runtime invocation source through to
# ToolManager so credential quotas, rate limits, and audit
# trails match the actual call site (DEBUGGER for draft test
# run, SERVICE_API / WEB_APP for published run).
invoke_from=context.dify_context.invoke_from,
)
except WorkflowAgentPluginToolsBuildError as error:
raise WorkflowAgentRuntimeRequestBuildError(error.error_code, str(error)) from error
if tools_layer is not None:
metadata["agent_tools"] = {
"dify_tool_count": len(tools_layer.tools),
"dify_tool_names": [tool.name or tool.tool_name for tool in tools_layer.tools],
"cli_tool_count": len(agent_soul.tools.cli_tools),
}
request = self._request_builder.build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id=context.dify_context.tenant_id,
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
model=agent_soul.model.model,
user_id=context.dify_context.user_id,
credentials=self._normalize_credentials(credentials),
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
),
execution_context=ExecutionContext(
# The execution-context layer is now the only public protocol
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
# be forwarded here because downstream plugin-daemon provider and
# tool clients read it from this layer rather than from any
# parallel top-level request field.
execution_context=DifyExecutionContextLayerConfig(
tenant_id=context.dify_context.tenant_id,
user_id=context.dify_context.user_id,
app_id=context.dify_context.app_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,
@ -121,6 +157,7 @@ class WorkflowAgentRuntimeRequestBuilder:
workflow_node_job_prompt=workflow_job_prompt,
user_prompt=user_prompt,
output=self._build_output_config(node_job.declared_outputs),
tools=tools_layer,
idempotency_key=self._idempotency_key(context),
metadata=metadata,
)
@ -142,9 +179,13 @@ class WorkflowAgentRuntimeRequestBuilder:
@staticmethod
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
if context.workflow_run_id:
return f"{context.workflow_run_id}:{context.node_execution_id}"
return context.node_execution_id
# Stage 4 §7 / D-4: retries get distinct keys (``...:retry-{attempt}``) so
# the Agent backend's protocol-level dedup can't replay a previous run.
return retry_idempotency_key(
workflow_run_id=context.workflow_run_id,
node_execution_id=context.node_execution_id,
attempt=context.attempt,
)
@staticmethod
def _build_metadata(
@ -237,11 +278,17 @@ class WorkflowAgentRuntimeRequestBuilder:
@staticmethod
def _build_output_config(declared_outputs: Sequence[DeclaredOutputConfig]) -> AgentBackendOutputConfig | None:
if not declared_outputs:
return None
"""Build the structured-output layer config sent to Agent backend.
Stage 4 §4.1 (D-3): when the user hasn't declared any outputs, inject the
PRD-mandated defaults (text / files / json) at runtime so the backend
always receives a stable schema and the downstream Inspector + nodes
have consistent output names. The defaults are NOT persisted.
"""
effective_outputs = WorkflowAgentRuntimeRequestBuilder.effective_declared_outputs(declared_outputs)
properties: dict[str, Any] = {}
required: list[str] = []
for output in declared_outputs:
for output in effective_outputs:
properties[output.name] = WorkflowAgentRuntimeRequestBuilder._schema_for_declared_output(output)
if output.required:
required.append(output.name)
@ -250,21 +297,52 @@ class WorkflowAgentRuntimeRequestBuilder:
schema["required"] = required
return AgentBackendOutputConfig(json_schema=schema)
@staticmethod
def effective_declared_outputs(
declared_outputs: Sequence[DeclaredOutputConfig],
) -> Sequence[DeclaredOutputConfig]:
"""Alias for :func:`models.agent_config_entities.effective_declared_outputs`.
Kept as a static method on the builder so existing call sites
(``agent_node._run``, tests) don't need to change their import.
"""
return _effective_declared_outputs(list(declared_outputs))
@staticmethod
def _schema_for_declared_output(output: DeclaredOutputConfig) -> dict[str, Any]:
match output.type:
schema = WorkflowAgentRuntimeRequestBuilder._schema_for_type(output.type, array_item=output.array_item)
if output.description:
schema["description"] = output.description
return schema
@staticmethod
def _schema_for_type(
output_type: DeclaredOutputType,
*,
array_item: DeclaredArrayItem | None = None,
) -> dict[str, Any]:
match output_type:
case DeclaredOutputType.STRING:
schema: dict[str, Any] = {"type": "string"}
return {"type": "string"}
case DeclaredOutputType.NUMBER:
schema = {"type": "number"}
return {"type": "number"}
case DeclaredOutputType.BOOLEAN:
schema = {"type": "boolean"}
return {"type": "boolean"}
case DeclaredOutputType.OBJECT:
schema = {"type": "object"}
return {"type": "object"}
case DeclaredOutputType.ARRAY:
schema = {"type": "array"}
# Stage 4 §4.2: items shape mirrors the declared array_item.
# Validator guarantees array_item is set when type is array.
item_type = array_item.type if array_item else DeclaredOutputType.OBJECT
schema: dict[str, Any] = {
"type": "array",
"items": WorkflowAgentRuntimeRequestBuilder._schema_for_type(item_type),
}
if array_item is not None and array_item.description:
schema["items"]["description"] = array_item.description
return schema
case DeclaredOutputType.FILE:
schema = {
return {
"type": "object",
"properties": {
"file_id": {"type": "string"},
@ -273,9 +351,6 @@ class WorkflowAgentRuntimeRequestBuilder:
"url": {"type": "string"},
},
}
if output.description:
schema["description"] = output.description
return schema
@staticmethod
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:

View File

@ -126,6 +126,7 @@ class WorkflowAgentNodeValidator:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} requires Agent Soul model config."
)
cls._validate_agent_soul_tools(binding=binding, agent_soul=agent_soul)
node_job = WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict)
cls.validate_node_job(session=session, binding=binding, node_job=node_job, topology=topology)
@ -147,14 +148,15 @@ class WorkflowAgentNodeValidator:
f"Workflow Agent node {binding.node_id} has duplicate output name {output.name}."
)
output_names.add(output.name)
for check in output.checks:
if check.benchmark_file_ref is not None:
cls._validate_file_ref(
session=session,
binding=binding,
file_ref=check.benchmark_file_ref,
ref_context=f"output {output.name} benchmark file",
)
# Stage 4 §4.3: declared output carries a single optional check, gated by
# ``check.enabled``. Only enabled checks need their benchmark file resolved.
if output.check is not None and output.check.enabled and output.check.benchmark_file_ref is not None:
cls._validate_file_ref(
session=session,
binding=binding,
file_ref=output.check.benchmark_file_ref,
ref_context=f"output {output.name} benchmark file",
)
for ref in node_job.previous_node_output_refs:
selector = cls.selector_from_ref(ref)
@ -279,6 +281,26 @@ class WorkflowAgentNodeValidator:
f"Workflow Agent node {binding.node_id} references unsupported human contact channel {channel}."
)
@classmethod
def _validate_agent_soul_tools(
cls,
*,
binding: WorkflowAgentNodeBinding,
agent_soul: AgentSoulConfig,
) -> None:
exposed_names: set[str] = set()
for tool in agent_soul.tools.dify_tools:
if not tool.enabled:
continue
exposed_name = tool.tool_name
if exposed_name in exposed_names:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} has duplicate Dify Plugin Tool name {exposed_name}."
)
exposed_names.add(exposed_name)
# CLI tools remain saved-but-not-executed. They are allowed at publish
# time so existing Agent Soul drafts are not blocked by a reserved field.
@staticmethod
def _validate_file_ref(
*,

View File

@ -45,6 +45,7 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
)
@ -161,6 +162,8 @@ def create_spec_app() -> Flask:
from controllers.console import bp as console_bp
from controllers.console import console_ns
from controllers.openapi import bp as openapi_bp
from controllers.openapi import openapi_ns
from controllers.service_api import bp as service_api_bp
from controllers.service_api import service_api_ns
from controllers.web import bp as web_bp
@ -169,8 +172,9 @@ def create_spec_app() -> Flask:
app.register_blueprint(console_bp)
app.register_blueprint(web_bp)
app.register_blueprint(service_api_bp)
app.register_blueprint(openapi_bp)
for namespace in (console_ns, web_ns, service_api_ns):
for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
for api in namespace.apis:
_materialize_inline_model_definitions(api)
@ -201,6 +205,13 @@ def _registered_models(namespace: str) -> dict[str, object]:
for api in service_api_ns.apis:
models.update(api.models)
return models
if namespace == "openapi":
from controllers.openapi import openapi_ns
models = dict(openapi_ns.models)
for api in openapi_ns.apis:
models.update(api.models)
return models
raise ValueError(f"unknown Swagger namespace: {namespace}")

View File

@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
OPENAPI_MAX_AGE_SECONDS: int = 600
def _apply_cors_once(bp, /, **cors_kwargs):
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
from controllers.mcp import bp as mcp_bp
from controllers.openapi import bp as openapi_bp
from controllers.service_api import bp as service_api_bp
from controllers.trigger import bp as trigger_bp
from controllers.web import bp as web_bp
@ -41,6 +44,23 @@ def init_app(app: DifyApp):
)
app.register_blueprint(service_api_bp)
if dify_config.OPENAPI_ENABLED:
# User-scoped programmatic API. Default empty allowlist = same-origin
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
# integrations. supports_credentials so cookie-authed approve/deny
# work; cross-origin OPTIONS without an allowed origin will fail
# the same as on the console blueprint.
_apply_cors_once(
openapi_bp,
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=list(OPENAPI_HEADERS),
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
expose_headers=list(EXPOSED_HEADERS),
max_age=OPENAPI_MAX_AGE_SECONDS,
)
app.register_blueprint(openapi_bp)
_apply_cors_once(
web_bp,
resources={

View File

@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
"schedule": crontab(minute="0", hour="0"),
}
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
imports.append("schedule.clean_oauth_access_tokens_task")
beat_schedule["clean_oauth_access_tokens_task"] = {
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
}
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
imports.append("schedule.workflow_schedule_task")
beat_schedule["workflow_schedule_task"] = {

View File

@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token, extract_webapp_passport
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
return logged_in_account
elif request.blueprint == "openapi":
# Account-branch device-flow approval routes (approve / deny /
# approval-context) sit under @login_required and authenticate via
# the console session cookie. Cookie-only on purpose — bearer
# tokens (dfoa_/dfoe_) live on the Authorization header and are
# validated by AppPipeline, not flask-login.
cookie_token = extract_console_cookie_token(request)
if not cookie_token:
return None
try:
decoded = PassportService().verify(cookie_token)
except Exception:
return None
user_id = decoded.get("user_id")
source = decoded.get("token_source")
if source or not user_id:
return None
return AccountService.load_logged_in_account(account_id=user_id)
elif request.blueprint == "web":
app_code = request.headers.get(HEADER_NAME_APP_CODE)
webapp_token = extract_webapp_passport(app_code, request) if app_code else None

View File

@ -0,0 +1,23 @@
"""Bind the bearer authenticator at startup. Must run after ext_database
and ext_redis (needs both factories).
"""
from __future__ import annotations
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import build_and_bind
def is_enabled() -> bool:
return dify_config.ENABLE_OAUTH_BEARER
def init_app(app: DifyApp) -> None:
# scoped_session isn't a context manager; request teardown closes it.
def session_factory():
return db.session
build_and_bind(session_factory=session_factory, redis_client=redis_client)

View File

@ -1,7 +1,7 @@
from datetime import datetime
from flask_restx import fields
from pydantic import field_validator
from pydantic import Field, field_validator
from fields.base import ResponseModel
from libs.helper import TimestampField, to_timestamp
@ -152,31 +152,41 @@ class DatasetRerankingModelResponse(ResponseModel):
class DatasetKeywordSettingResponse(ResponseModel):
keyword_weight: float
keyword_weight: float | None = None
class DatasetVectorSettingResponse(ResponseModel):
vector_weight: float
embedding_model_name: str
embedding_provider_name: str
vector_weight: float | None = None
embedding_model_name: str | None = None
embedding_provider_name: str | None = None
class DatasetWeightedScoreResponse(ResponseModel):
weight_type: str | None
keyword_setting: DatasetKeywordSettingResponse | None
vector_setting: DatasetVectorSettingResponse | None
weight_type: str | None = None
keyword_setting: DatasetKeywordSettingResponse = Field(default_factory=DatasetKeywordSettingResponse)
vector_setting: DatasetVectorSettingResponse = Field(default_factory=DatasetVectorSettingResponse)
@field_validator("keyword_setting", "vector_setting", mode="before")
@classmethod
def _expand_null_nested(cls, value: object) -> object:
return {} if value is None else value
class DatasetRetrievalModelResponse(ResponseModel):
search_method: str
reranking_enable: bool
reranking_mode: str | None = None
reranking_model: DatasetRerankingModelResponse | None
reranking_model: DatasetRerankingModelResponse = Field(default_factory=DatasetRerankingModelResponse)
weights: DatasetWeightedScoreResponse | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: float | None = None
@field_validator("reranking_model", mode="before")
@classmethod
def _expand_null_nested(cls, value: object) -> object:
return {} if value is None else value
class DatasetSummaryIndexSettingResponse(ResponseModel):
enable: bool | None = None
@ -192,15 +202,15 @@ class DatasetTagResponse(ResponseModel):
class DatasetExternalKnowledgeInfoResponse(ResponseModel):
external_knowledge_id: str
external_knowledge_api_id: str
external_knowledge_api_name: str
external_knowledge_api_endpoint: str
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
external_knowledge_api_name: str | None = None
external_knowledge_api_endpoint: str | None = None
class DatasetExternalRetrievalModelResponse(ResponseModel):
top_k: int
score_threshold: float
score_threshold: float | None = None
score_threshold_enabled: bool | None = None
@ -211,8 +221,8 @@ class DatasetDocMetadataResponse(ResponseModel):
class DatasetIconInfoResponse(ResponseModel):
icon_type: str | None
icon: str | None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
icon_url: str | None = None
@ -237,17 +247,21 @@ class DatasetDetailResponse(ResponseModel):
embedding_model_provider: str | None
embedding_available: bool | None = None
retrieval_model_dict: DatasetRetrievalModelResponse
summary_index_setting: DatasetSummaryIndexSettingResponse | None
summary_index_setting: DatasetSummaryIndexSettingResponse = Field(
default_factory=DatasetSummaryIndexSettingResponse
)
tags: list[DatasetTagResponse]
doc_form: str | None
external_knowledge_info: DatasetExternalKnowledgeInfoResponse | None
external_knowledge_info: DatasetExternalKnowledgeInfoResponse = Field(
default_factory=DatasetExternalKnowledgeInfoResponse
)
external_retrieval_model: DatasetExternalRetrievalModelResponse | None
doc_metadata: list[DatasetDocMetadataResponse]
built_in_field_enabled: bool
pipeline_id: str | None
runtime_mode: str | None
chunk_structure: str | None
icon_info: DatasetIconInfoResponse | None
icon_info: DatasetIconInfoResponse = Field(default_factory=DatasetIconInfoResponse)
is_published: bool
total_documents: int
total_available_documents: int
@ -258,3 +272,8 @@ class DatasetDetailResponse(ResponseModel):
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
@field_validator("summary_index_setting", "external_knowledge_info", "icon_info", mode="before")
@classmethod
def _expand_null_nested(cls, value: object) -> object:
return {} if value is None else value

View File

@ -0,0 +1,205 @@
"""Device-flow security primitives: enterprise_only gate, approval-grant
cookie mint/verify/consume, and anti-framing headers.
"""
from __future__ import annotations
import logging
import secrets
from collections.abc import Callable
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from functools import wraps
from flask import Blueprint
from pydantic import ValidationError
from werkzeug.exceptions import NotFound
from libs import jws
from libs.token import is_secure
from services.feature_service import FeatureService, LicenseStatus
logger = logging.getLogger(__name__)
# ============================================================================
# enterprise_only decorator
# ============================================================================
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
responses don't consume the bucket.
"""
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
if settings.license.status not in _EE_ENABLED_STATUSES:
raise NotFound()
return view(*args, **kwargs)
return decorated
# ============================================================================
# approval_grant cookie
# ============================================================================
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
subject_email: str
subject_issuer: str
user_code: str
nonce: str
csrf_token: str
expires_at: datetime
def mint_approval_grant(
*,
keyset: jws.KeySet,
iss: str,
subject_email: str,
subject_issuer: str,
user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
source of truth for Path/HttpOnly/Secure/SameSite.
"""
now = datetime.now(UTC)
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
nonce = _random_opaque()
csrf_token = _random_opaque()
payload = {
"iss": iss,
"subject_email": subject_email,
"subject_issuer": subject_issuer,
"user_code": user_code,
"nonce": nonce,
"csrf_token": csrf_token,
}
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
return token, ApprovalGrantClaims(
subject_email=subject_email,
subject_issuer=subject_issuer,
user_code=user_code,
nonce=nonce,
csrf_token=csrf_token,
expires_at=exp,
)
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
"""Sig + aud + exp only — nonce consumption is the caller's job."""
# lazy import: breaks libs → controllers cycle
from controllers.openapi._models import ApprovalGrantClaimsPayload
raw = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
try:
parsed = ApprovalGrantClaimsPayload.model_validate(raw)
except ValidationError as e:
raise jws.VerifyError(f"claim shape invalid: {e}") from e
return ApprovalGrantClaims(
subject_email=parsed.subject_email,
subject_issuer=parsed.subject_issuer,
user_code=parsed.user_code,
nonce=parsed.nonce,
csrf_token=parsed.csrf_token,
expires_at=datetime.fromtimestamp(raw["exp"], tz=UTC),
)
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
if not nonce:
return False
return bool(
redis_client.set(
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
"1",
nx=True,
ex=NONCE_TTL_SECONDS,
)
)
def approval_grant_cookie_kwargs(value: str) -> dict:
"""``secure`` follows is_secure() so HTTP-only deployments don't
silently drop the cookie.
"""
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": value,
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def approval_grant_cleared_cookie_kwargs() -> dict:
return {
"key": APPROVAL_GRANT_COOKIE_NAME,
"value": "",
"max_age": 0,
"path": APPROVAL_GRANT_COOKIE_PATH,
"secure": is_secure(),
"httponly": True,
"samesite": "Lax",
}
def _random_opaque() -> str:
return secrets.token_urlsafe(16)
# ============================================================================
# Anti-framing headers
# ============================================================================
_ANTI_FRAMING_HEADERS = {
"X-Frame-Options": "DENY",
"Content-Security-Policy": "frame-ancestors 'none'",
}
def attach_anti_framing(bp: Blueprint) -> None:
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
@bp.after_request
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
for name, value in _ANTI_FRAMING_HEADERS.items():
response.headers.setdefault(name, value)
return response

View File

@ -76,6 +76,7 @@ def register_external_error_handlers(api: Api):
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
current_app.logger.exception("value_error in request handler")
status_code = 400
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code

View File

@ -595,3 +595,18 @@ class RateLimiter:
self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2)
def seconds_until_available(self, email: str) -> int:
"""Seconds until the oldest in-window entry expires, freeing a slot.
Defensive floor of 1 second. Caller should only invoke this after
is_rate_limited() returned True.
"""
key = self._get_key(email)
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
if not oldest:
return 1
_member, score = oldest[0]
free_at = int(score) + self.time_window
remaining = free_at - int(time.time())
return max(remaining, 1)

108
api/libs/jws.py Normal file
View File

@ -0,0 +1,108 @@
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
state envelope, external subject assertion, and approval-grant cookie —
all three share one key-set so api ↔ enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
ACTIVE_KID_V1 = "dify-shared-v1"
class KeySetError(Exception):
pass
class KeySet:
"""``from_entries`` reserves multi-kid construction for rotation slots."""
def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
if active_kid not in entries:
raise KeySetError(f"active kid {active_kid!r} missing from key-set")
if not entries[active_kid]:
raise KeySetError(f"active kid {active_kid!r} has empty secret")
self._entries: dict[str, bytes] = {k: bytes(v) for k, v in entries.items()}
self._active_kid = active_kid
@classmethod
def from_shared_secret(cls) -> KeySet:
secret = dify_config.SECRET_KEY
if not secret:
raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)
@classmethod
def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
return cls(entries, active_kid)
@property
def active_kid(self) -> str:
return self._active_kid
def lookup(self, kid: str) -> bytes | None:
return self._entries.get(kid)
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
"""``iat`` + ``exp`` are injected here; callers must not set them."""
if "aud" in payload or "iat" in payload or "exp" in payload:
raise ValueError("reserved claim present in payload (aud/iat/exp)")
if ttl_seconds <= 0:
raise ValueError("ttl_seconds must be positive")
kid = keyset.active_kid
secret = keyset.lookup(kid)
if secret is None:
raise KeySetError(f"active kid {kid!r} lookup miss")
iat = datetime.now(UTC)
exp = iat + timedelta(seconds=ttl_seconds)
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
return jwt.encode(
claims,
secret,
algorithm="HS256",
headers={"kid": kid, "typ": "JWT"},
)
class VerifyError(Exception):
pass
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
"""Unknown kid is rejected — never fall back to the active kid, since
a past kid value would otherwise be forgeable by anyone who saw it.
"""
try:
header = jwt.get_unverified_header(token)
except jwt.PyJWTError as e:
raise VerifyError(f"decode header: {e}") from e
kid = header.get("kid")
if not kid:
raise VerifyError("no kid in header")
secret = keyset.lookup(kid)
if secret is None:
raise VerifyError(f"unknown kid {kid!r}")
try:
return jwt.decode(
token,
secret,
algorithms=["HS256"],
audience=expected_aud,
)
except jwt.ExpiredSignatureError as e:
raise VerifyError("token expired") from e
except jwt.InvalidAudienceError as e:
raise VerifyError("aud mismatch") from e
except jwt.PyJWTError as e:
raise VerifyError(f"decode: {e}") from e

685
api/libs/oauth_bearer.py Normal file
View File

@ -0,0 +1,685 @@
"""OAuth bearer primitives.
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
Authenticator + validate_bearer stay untouched.
"""
from __future__ import annotations
import hashlib
import json
import logging
import uuid
from collections.abc import Callable, Iterable
from contextvars import ContextVar, Token
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
from functools import wraps
from typing import Literal, ParamSpec, Protocol, TypeVar
from flask import request
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.rate_limit import enforce_bearer_rate_limit
from models import Account, OAuthAccessToken, TenantAccountJoin
logger = logging.getLogger(__name__)
# ============================================================================
# Contract — types, enums, protocols
# ============================================================================
class SubjectType(StrEnum):
ACCOUNT = "account"
EXTERNAL_SSO = "external_sso"
class Scope(StrEnum):
"""Catalog of bearer scopes recognised by the openapi surface.
`FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
(`APPS_RUN`, `APPS_READ_PERMITTED_EXTERNAL`).
"""
FULL = "full"
APPS_READ = "apps:read"
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
APPS_RUN = "apps:run"
class Accepts(StrEnum):
"""Subject types a route is willing to accept as caller."""
USER_ACCOUNT = "user_account"
USER_EXT_SSO = "user_ext_sso"
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
@dataclass(frozen=True, slots=True)
class AuthContext:
"""Per-request identity published via :data:`_auth_ctx_var`
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
``subject_type`` / ``source`` come from the TokenKind, not the DB —
corrupt rows can't elevate scope.
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
authenticate time. Per-request mutations write through to Redis via
`record_layer0_verdict`; this snapshot is not updated in place (frozen).
"""
subject_type: SubjectType
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
scopes: frozenset[Scope]
token_id: uuid.UUID
source: str
expires_at: datetime | None
token_hash: str
verified_tenants: dict[str, bool] = field(default_factory=dict)
_auth_ctx_var: ContextVar[AuthContext] = ContextVar("openapi_auth_ctx")
def set_auth_ctx(ctx: AuthContext) -> Token[AuthContext]:
return _auth_ctx_var.set(ctx)
def reset_auth_ctx(token: Token[AuthContext]) -> None:
_auth_ctx_var.reset(token)
def get_auth_ctx() -> AuthContext:
return _auth_ctx_var.get()
def try_get_auth_ctx() -> AuthContext | None:
return _auth_ctx_var.get(None)
@dataclass(frozen=True, slots=True)
class ResolvedRow:
subject_email: str | None
subject_issuer: str | None
account_id: uuid.UUID | None
client_id: str | None
token_id: uuid.UUID
expires_at: datetime | None
verified_tenants: dict[str, bool] = field(default_factory=dict)
def to_cache(self) -> dict:
return {
"subject_email": self.subject_email,
"subject_issuer": self.subject_issuer,
"account_id": str(self.account_id) if self.account_id else None,
"client_id": self.client_id,
"token_id": str(self.token_id),
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
"verified_tenants": dict(self.verified_tenants),
}
@classmethod
def from_cache(cls, data: dict) -> ResolvedRow:
return cls(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
client_id=data.get("client_id"),
token_id=uuid.UUID(data["token_id"]),
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
)
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
on all live deployments (60s TTL → safe to drop one release after rollout).
"""
if not isinstance(raw, dict):
return {}
out: dict[str, bool] = {}
for k, v in raw.items():
if isinstance(v, bool):
out[k] = v
elif v == "ok":
out[k] = True
elif v == "denied":
out[k] = False
return out
class Resolver(Protocol):
def resolve(self, token_hash: str) -> ResolvedRow | None: # pragma: no cover - contract
...
@dataclass(frozen=True, slots=True)
class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[Scope]
source: str
resolver: Resolver
def matches(self, token: str) -> bool:
return token.startswith(self.prefix)
@dataclass(frozen=True, slots=True)
class MintProfile:
"""Single source of truth for (subject_type, prefix, scopes) at mint time.
Consumers:
- ``build_registry`` reads scopes here so the resolve-time TokenKind
cannot drift from the mint-time intent.
- Device-flow ``approve`` / ``approve-external`` read prefix + scopes
here when calling ``mint_oauth_token`` and ``validate_mint_policy``.
- ``services.openapi.mint_policy.validate_mint_policy`` cross-checks
the (subject_type, prefix, scopes) triple a caller intends to mint
against this table — a caller that assembles its own scope set
from a non-canonical source will fail closed at approve time.
"""
subject_type: SubjectType
prefix: str
scopes: frozenset[Scope]
MINTABLE_PROFILES: dict[SubjectType, MintProfile] = {
SubjectType.ACCOUNT: MintProfile(
subject_type=SubjectType.ACCOUNT,
prefix="dfoa_",
scopes=frozenset({Scope.FULL}),
),
SubjectType.EXTERNAL_SSO: MintProfile(
subject_type=SubjectType.EXTERNAL_SSO,
prefix="dfoe_",
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
),
}
class InvalidBearerError(Exception):
"""Token missing, unknown prefix, or no live row."""
class TokenExpiredError(Exception):
"""Hard-expire bookkeeping is the resolver's job before raising."""
# ============================================================================
# Registry
# ============================================================================
class TokenKindRegistry:
def __init__(self, kinds: Iterable[TokenKind]) -> None:
self._kinds: tuple[TokenKind, ...] = tuple(kinds)
prefixes = [k.prefix for k in self._kinds]
if len(set(prefixes)) != len(prefixes):
raise ValueError(f"duplicate prefix in registry: {prefixes}")
def find(self, token: str) -> TokenKind | None:
for k in self._kinds:
if k.matches(token):
return k
return None
def kinds(self) -> tuple[TokenKind, ...]:
return self._kinds
# ============================================================================
# Authenticator
# ============================================================================
def sha256_hex(token: str) -> str:
return hashlib.sha256(token.encode("utf-8")).hexdigest()
class BearerAuthenticator:
def __init__(self, registry: TokenKindRegistry) -> None:
self._registry = registry
@property
def registry(self) -> TokenKindRegistry:
return self._registry
def authenticate(self, token: str) -> AuthContext:
"""Identity + per-token rate limit (single source).
Both the openapi pipeline (`BearerCheck`) and the decorator
(`validate_bearer`) call this — rate-limit fires exactly once per
request regardless of which path hosts the route.
"""
kind = self._registry.find(token)
if kind is None:
raise InvalidBearerError("invalid_bearer")
token_hash = sha256_hex(token)
enforce_bearer_rate_limit(token_hash)
row = kind.resolver.resolve(token_hash)
if row is None:
raise InvalidBearerError("invalid_bearer")
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=row.account_id,
client_id=row.client_id,
scopes=kind.scopes,
token_id=row.token_id,
source=kind.source,
expires_at=row.expires_at,
token_hash=token_hash,
verified_tenants=dict(row.verified_tenants),
)
# ============================================================================
# OAuth access token resolver (PAT resolver would be a sibling class)
# ============================================================================
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"
ScopeVariant = Literal["account", "external_sso"]
class OAuthAccessTokenResolver:
"""``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
sharing DB + cache plumbing.
"""
def __init__(
self,
session_factory,
redis_client,
positive_ttl: int = POSITIVE_TTL_SECONDS,
negative_ttl: int = NEGATIVE_TTL_SECONDS,
) -> None:
self.session_factory = session_factory
self._redis = redis_client
self._positive_ttl = positive_ttl
self._negative_ttl = negative_ttl
def for_account(self) -> Resolver:
return _VariantResolver(self, variant="account")
def for_external_sso(self) -> Resolver:
return _VariantResolver(self, variant="external_sso")
def _cache_key(self, token_hash: str) -> str:
return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
raw = self._redis.get(self._cache_key(token_hash))
if raw is None:
return None
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return "invalid"
try:
return ResolvedRow.from_cache(json.loads(text))
except (ValueError, KeyError):
logger.warning("auth:token cache entry malformed; treating as miss")
return None
def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
self._redis.setex(
self._cache_key(token_hash),
self._positive_ttl,
json.dumps(row.to_cache()),
)
def cache_set_negative(self, token_hash: str) -> None:
self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")
def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
"""Atomic CAS — only the worker that flips revoked_at emits audit;
replays are idempotent.
"""
stmt = (
update(OAuthAccessToken)
.where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
.values(revoked_at=datetime.now(UTC), token_hash=None)
)
result = session.execute(stmt)
session.commit()
if result.rowcount == 1: # type: ignore
logger.warning(
"audit: %s token_id=%s",
AUDIT_OAUTH_EXPIRED,
row_id,
extra={"audit": True, "token_id": str(row_id)},
)
self._redis.delete(self._cache_key(token_hash))
self.cache_set_negative(token_hash)
class _VariantResolver:
def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
self._parent = parent
self._variant = variant
def resolve(self, token_hash: str) -> ResolvedRow | None:
cached = self._parent.cache_get(token_hash)
if cached == "invalid":
return None
if cached is not None and not isinstance(cached, str):
if not self._matches_variant(cached):
return None
return cached
# Flask-SQLAlchemy's scoped_session is request-bound and not a
# context manager; use it directly.
session = self._parent.session_factory()
row = self._load_from_db(session, token_hash)
if row is None:
self._parent.cache_set_negative(token_hash)
return None
now = datetime.now(UTC)
if row.expires_at is not None and row.expires_at <= now:
self._parent.hard_expire(session, row.id, token_hash)
return None
if not self._matches_variant_model(row):
logger.error(
"internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
row.id,
row.prefix,
)
return None
resolved = ResolvedRow(
subject_email=row.subject_email,
subject_issuer=row.subject_issuer,
account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
client_id=row.client_id,
token_id=uuid.UUID(str(row.id)),
expires_at=row.expires_at,
)
self._parent.cache_set_positive(token_hash, resolved)
return resolved
def _matches_variant(self, row: ResolvedRow) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account
return not has_account
def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
has_account = row.account_id is not None
if self._variant == "account":
return has_account and row.prefix == "dfoa_"
return (not has_account) and row.prefix == "dfoe_"
def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
return (
session.query(OAuthAccessToken)
.filter(
OAuthAccessToken.token_hash == token_hash,
OAuthAccessToken.revoked_at.is_(None),
)
.one_or_none()
)
# ============================================================================
# Layer 0 — workspace membership cache + helper
# ============================================================================
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
"""Merge a Layer-0 membership verdict into the AuthContext cache entry at
`auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
rebuilds via authenticate() and re-runs Layer 0.
"""
cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
raw = redis_client.get(cache_key)
if raw is None:
return
text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
if text == "invalid":
return
try:
data = json.loads(text)
except (ValueError, KeyError):
return
ttl = redis_client.ttl(cache_key)
if ttl <= 0:
return
data.setdefault("verified_tenants", {})[tenant_id] = verdict
redis_client.setex(cache_key, ttl, json.dumps(data))
def check_workspace_membership(
*,
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
cached_verdicts: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
inline helper (`require_workspace_member`). Caller is responsible for
short-circuiting on EE / SSO subjects before invoking — this function
runs the membership + active-status checks unconditionally.
"""
cached = cached_verdicts.get(tenant_id)
if cached is True:
return
if cached is False:
raise Forbidden("workspace_membership_revoked")
join = db.session.execute(
select(TenantAccountJoin.id).where(
TenantAccountJoin.account_id == account_id,
TenantAccountJoin.tenant_id == tenant_id,
)
).scalar_one_or_none()
if join is None:
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
if status != "active":
record_layer0_verdict(token_hash, tenant_id, False)
raise Forbidden("workspace_membership_revoked")
record_layer0_verdict(token_hash, tenant_id, True)
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
"""AuthContext-flavoured wrapper around `check_workspace_membership`.
No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
(no `tenant_account_joins` row by definition).
"""
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
return
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=tenant_id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.verified_tenants,
)
# ============================================================================
# Decorator — route-level bearer gate
# ============================================================================
_authenticator: BearerAuthenticator | None = None
def bind_authenticator(authenticator: BearerAuthenticator) -> None:
global _authenticator
_authenticator = authenticator
def get_authenticator() -> BearerAuthenticator:
if _authenticator is None:
raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
return _authenticator
def extract_bearer(req) -> str | None:
"""Pull the bearer token out of an HTTP request's Authorization header.
Used by both attachment paths (the ``validate_bearer`` decorator and the
openapi ``Pipeline.guard``) so the parsing rule lives in one place. Pipeline
callers extract once at the boundary and pass the token through ``Context``
so steps stay independent of the request object.
"""
header = req.headers.get("Authorization", "")
scheme, _, value = header.partition(" ")
if scheme.lower() != "bearer" or not value:
return None
return value.strip()
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
"""Opt-in: omitting it leaves the route unauthenticated.
Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
and are rejected here as the wrong auth scheme for this surface.
"""
def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
@wraps(fn)
def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
token = extract_bearer(request)
if token is None:
raise Unauthorized("missing bearer token")
if _authenticator is None:
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
try:
ctx = get_authenticator().authenticate(token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
raise Forbidden("token subject type not accepted here")
# Try/finally pairing — the WSGI worker thread is reused
# across requests, so a leaked ContextVar would publish the
# previous caller's identity to the next request.
reset_token = set_auth_ctx(ctx)
try:
return fn(*args, **kwargs)
finally:
reset_auth_ctx(reset_token)
return inner
return wrap
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
without the authenticator, so fail fast instead of approving silently.
"""
@wraps(fn)
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
if not dify_config.ENABLE_OAUTH_BEARER:
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
return fn(*args, **kwargs)
return inner
def require_scope(scope: Scope) -> Callable:
"""Route-level scope gate — must run AFTER validate_bearer so that
the auth ContextVar is set. Raises ``Forbidden('insufficient_scope: <scope>')``
when the bearer lacks both the requested scope and ``Scope.FULL``.
"""
def wrap(fn: Callable) -> Callable:
@wraps(fn)
def inner(*args, **kwargs):
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
)
if Scope.FULL not in ctx.scopes and scope not in ctx.scopes:
raise Forbidden(f"insufficient_scope: {scope}")
return fn(*args, **kwargs)
return inner
return wrap
# ============================================================================
# Wiring — called once from the app factory
# ============================================================================
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
oauth = OAuthAccessTokenResolver(session_factory, redis_client)
account = MINTABLE_PROFILES[SubjectType.ACCOUNT]
external = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
return TokenKindRegistry(
[
TokenKind(
prefix=account.prefix,
subject_type=account.subject_type,
scopes=account.scopes,
source="oauth_account",
resolver=oauth.for_account(),
),
TokenKind(
prefix=external.prefix,
subject_type=external.subject_type,
scopes=external.scopes,
source="oauth_external_sso",
resolver=oauth.for_external_sso(),
),
]
)
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
registry = build_registry(session_factory, redis_client)
auth = BearerAuthenticator(registry)
bind_authenticator(auth)
return auth

147
api/libs/rate_limit.py Normal file
View File

@ -0,0 +1,147 @@
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
window Redis ZSET). Apply after auth decorators so account/email/token-id
scopes can read the openapi auth ContextVar (see
:func:`libs.oauth_bearer.try_get_auth_ctx`). Use :func:`enforce` when the
bucket key is computed in-handler. RFC-8628 ``slow_down`` is inline — its
response shape isn't generic 429.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from enum import StrEnum
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import jsonify, make_response, request, session
from werkzeug.exceptions import TooManyRequests
from configs import dify_config
from libs.helper import RateLimiter, extract_remote_ip
class RateLimitScope(StrEnum):
IP = "ip"
SESSION = "session"
ACCOUNT = "account"
SUBJECT_EMAIL = "subject_email"
TOKEN_ID = "token_id"
@dataclass(frozen=True, slots=True)
class RateLimit:
limit: int
window: timedelta
scopes: tuple[RateLimitScope, ...]
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_BEARER_PER_TOKEN = RateLimit(
limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
window=timedelta(minutes=1),
scopes=(RateLimitScope.TOKEN_ID,), # bucket key composed by caller from sha256(token)
)
def _one_key(scope: RateLimitScope) -> str:
match scope:
case RateLimitScope.IP:
return f"ip:{extract_remote_ip(request) or 'unknown'}"
case RateLimitScope.SESSION:
return f"session:{session.get('_id', 'anon')}"
case RateLimitScope.ACCOUNT:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.account_id:
return f"account:{ctx.account_id}"
return "account:anon"
case RateLimitScope.SUBJECT_EMAIL:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.subject_email:
return f"subject:{ctx.subject_email}"
return "subject:anon"
case RateLimitScope.TOKEN_ID:
from libs.oauth_bearer import try_get_auth_ctx
ctx = try_get_auth_ctx()
if ctx and ctx.token_id:
return f"token:{ctx.token_id}"
return "token:anon"
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
return "|".join(_one_key(s) for s in scopes)
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
return "rl:" + "+".join(s.value for s in scopes)
def _build_limiter(spec: RateLimit) -> RateLimiter:
return RateLimiter(
prefix=_limiter_prefix(spec.scopes),
max_attempts=spec.limit,
time_window=int(spec.window.total_seconds()),
)
_P = ParamSpec("_P")
_R = TypeVar("_R")
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
"""Apply after auth decorators that the scopes read from."""
limiter = _build_limiter(spec)
def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
key = _composite_key(spec.scopes)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
return fn(*args, **kwargs)
return inner
return wrap
def enforce(spec: RateLimit, *, key: str) -> None:
"""Imperative form — caller composes the bucket key to match scope
semantics (the key is opaque here).
"""
limiter = _build_limiter(spec)
if limiter.is_rate_limited(key):
raise TooManyRequests("rate_limited")
limiter.increment_rate_limit(key)
def enforce_bearer_rate_limit(token_hash: str) -> None:
"""Per-token rate limit on /openapi/v1/* bearer-authed routes.
Bucket key = ``token:<sha256_hex>`` so the same token shares one
bucket across api replicas (Redis-backed sliding window).
"""
limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
key = f"token:{token_hash}"
if limiter.is_rate_limited(key):
retry_after = limiter.seconds_until_available(key)
response = make_response(
jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
429,
)
response.headers["Retry-After"] = str(retry_after)
raise TooManyRequests(response=response)
limiter.increment_rate_limit(key)

View File

@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
def extract_console_cookie_token(request: Request) -> str | None:
"""Cookie-only console session token. Used by /openapi/v1/oauth/device/*
approval routes, which must not fall through to the Authorization header
(that's where dfoa_/dfoe_ bearers live — they aren't JWTs)."""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_access_token(request: Request) -> str | None:
return extract_console_cookie_token(request) or _try_extract_from_header(request)
def extract_webapp_access_token(request: Request) -> str | None:

Some files were not shown because too many files have changed in this diff Show More