mirror of
https://github.com/langgenius/dify.git
synced 2026-05-26 20:07:46 +08:00
Compare commits
26 Commits
song/eng-3
...
codex/alig
| Author | SHA1 | Date | |
|---|---|---|---|
| 6a06bb45b3 | |||
| 34b19422a2 | |||
| 323b2b82e0 | |||
| 7d45335a32 | |||
| f5d664887b | |||
| 5aa24c25d9 | |||
| eed8d659d1 | |||
| 59e99ee1ae | |||
| 533929d314 | |||
| fb07b43107 | |||
| 0dad426101 | |||
| 2a1df4de62 | |||
| 2b97f6c8c2 | |||
| 75d6511284 | |||
| fd059720e5 | |||
| 2a5f7bb1aa | |||
| 0f06aa2fdd | |||
| 884e2b864b | |||
| a728e0ac69 | |||
| 7d464d014c | |||
| 0ce0127e7e | |||
| 25da7ae0d9 | |||
| 4d6f8eba2a | |||
| 87268f0662 | |||
| 135e01930b | |||
| fe86fa31ec |
15
.dockerignore
Normal file
15
.dockerignore
Normal file
@ -0,0 +1,15 @@
|
||||
**/node_modules
|
||||
**/.pnpm-store
|
||||
**/dist
|
||||
**/.next
|
||||
**/.turbo
|
||||
**/.cache
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/.mypy_cache
|
||||
**/.ruff_cache
|
||||
.git
|
||||
.github
|
||||
*.md
|
||||
!web/README.md
|
||||
!api/README.md
|
||||
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -5,3 +5,7 @@
|
||||
# them.
|
||||
|
||||
*.sh text eol=lf
|
||||
|
||||
# Codegen output must stay byte-identical across platforms so
|
||||
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
|
||||
*.generated.ts text eol=lf
|
||||
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,6 +18,10 @@
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# CLI
|
||||
/cli/ @langgenius/maintainers
|
||||
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
|
||||
111
.github/dependabot.yml
vendored
111
.github/dependabot.yml
vendored
@ -110,3 +110,114 @@ updates:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 5
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
88
.github/workflows/cli-release.yml
vendored
Normal file
88
.github/workflows/cli-release.yml
vendored
Normal file
@ -0,0 +1,88 @@
|
||||
name: CLI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- 'difyctl-v*'
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: build standalone binaries (all targets)
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Install cross-arch native prebuilds
|
||||
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||
# so `bun build --compile` can embed the right .node into each target.
|
||||
working-directory: ./
|
||||
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||
|
||||
- name: Compile standalone binaries (all targets)
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
run: |
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build:bin
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
run: scripts/release-write-checksums.sh
|
||||
|
||||
- name: Publish GitHub Release
|
||||
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||
with:
|
||||
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||
generate_release_notes: true
|
||||
fail_on_unmatched_files: true
|
||||
files: |
|
||||
cli/dist/bin/difyctl-v*
|
||||
60
.github/workflows/cli-smoke.yml
vendored
Normal file
60
.github/workflows/cli-smoke.yml
vendored
Normal file
@ -0,0 +1,60 @@
|
||||
name: CLI Smoke (live dify)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_version:
|
||||
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||
type: string
|
||||
required: true
|
||||
cli_ref:
|
||||
description: "Git ref to build the cli from (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
smoke:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Bring up dify
|
||||
env:
|
||||
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||
run: |
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||
docker compose up -d api worker web db redis
|
||||
for i in $(seq 1 60); do
|
||||
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||
echo "dify api ready after ${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run smoke against live dify
|
||||
working-directory: ./cli
|
||||
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||
|
||||
- name: Dump dify logs on failure
|
||||
if: failure()
|
||||
run: |
|
||||
cd docker
|
||||
docker compose logs api worker web --tail=200
|
||||
46
.github/workflows/cli-tests.yml
vendored
Normal file
46
.github/workflows/cli-tests.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: CLI Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,6 +42,7 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
@ -62,6 +63,18 @@ jobs:
|
||||
- 'docker/generate_docker_compose'
|
||||
- 'docker/ssrf_proxy/**'
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
cli:
|
||||
- 'cli/**'
|
||||
- 'packages/tsconfig/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- 'eslint.config.mjs'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/cli-tests.yml'
|
||||
- '.github/workflows/cli-docker-build.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
@ -184,6 +197,66 @@ jobs:
|
||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
cli-tests-run:
|
||||
name: Run CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||
uses: ./.github/workflows/cli-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests-skip:
|
||||
name: Skip CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped CLI tests
|
||||
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
if: ${{ always() }}
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
- cli-tests-run
|
||||
- cli-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize CLI Tests status
|
||||
env:
|
||||
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||
run: |
|
||||
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests ran successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests were skipped because no CLI-related files changed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
web-tests-run:
|
||||
name: Run Web Tests
|
||||
needs:
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@ -115,6 +115,12 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||
!/cli/src/env/
|
||||
!/cli/src/commands/env/
|
||||
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||
!/cli/scripts/lib/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
@ -247,8 +253,9 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
*.local.toml
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.context/*
|
||||
.context/
|
||||
.eslintcache
|
||||
|
||||
@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -30,7 +30,7 @@ from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
@ -42,7 +42,7 @@ from clients.agent_backend.request_builder import (
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
|
||||
@ -4,7 +4,9 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
|
||||
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
|
||||
protocol has a single owner. API-only context such as Agent Soul vs workflow job
|
||||
prompt is preserved in layer names and metadata until the dedicated product
|
||||
schemas land in later phases.
|
||||
schemas land in later phases. Dify-owned execution identifiers are emitted as an
|
||||
explicit ``dify.execution_context`` layer so the run request stays fully
|
||||
composition-driven.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -15,18 +17,19 @@ from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLayerConfig,
|
||||
DifyPluginLLMLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
ExecutionContext,
|
||||
LayerExitSignals,
|
||||
RunComposition,
|
||||
RunLayerSpec,
|
||||
@ -37,17 +40,15 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
model_provider: str
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
@ -73,7 +74,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: ExecutionContext
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
workflow_node_job_prompt: str
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
@ -125,21 +126,18 @@ class AgentBackendRunRequestBuilder:
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLayerConfig(
|
||||
tenant_id=run_input.model.tenant_id,
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
user_id=run_input.model.user_id,
|
||||
),
|
||||
config=run_input.execution_context,
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
@ -165,7 +163,6 @@ class AgentBackendRunRequestBuilder:
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
execution_context=run_input.execution_context,
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -23,7 +25,7 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
)
|
||||
|
||||
@ -525,6 +525,44 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
OPENAPI_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||
"programmatic clients. Set to true to activate; disabled by default."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -900,6 +938,17 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1186,6 +1235,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -68,6 +68,7 @@ from .app import (
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_node_output_inspector,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
@ -218,6 +219,7 @@ __all__ = [
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_node_output_inspector",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
|
||||
@ -9,18 +9,25 @@ from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
@ -40,7 +47,7 @@ class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -64,10 +71,11 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
def get(self, resource_id):
|
||||
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
|
||||
|
||||
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
@ -75,13 +83,14 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
|
||||
|
||||
@edit_permission_required
|
||||
def post(self, resource_id):
|
||||
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
|
||||
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
|
||||
|
||||
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
@ -108,7 +117,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
return api_token
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -118,9 +127,20 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
def delete(
|
||||
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
|
||||
) -> tuple[str, int]:
|
||||
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
def _delete_api_key(
|
||||
self,
|
||||
resource_id: str,
|
||||
api_key_id: str,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> None:
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
@ -147,8 +167,6 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@ -156,18 +174,21 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -181,9 +202,14 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -196,18 +222,21 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
def get(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id: UUID):
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
@ -221,9 +250,14 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.app.wraps import get_app_model, with_session
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -26,7 +26,6 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -852,11 +851,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_session
|
||||
@get_app_model
|
||||
def get(self, app_model):
|
||||
def get(self, session: Session, app_model: App):
|
||||
"""Get app trace"""
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
|
||||
415
api/controllers/console/app/workflow_node_output_inspector.py
Normal file
415
api/controllers/console/app/workflow_node_output_inspector.py
Normal file
@ -0,0 +1,415 @@
|
||||
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
|
||||
|
||||
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
|
||||
with a producer-organized view of each node's declared outputs and their
|
||||
per-run status. This module exposes two parallel sets of three read-only
|
||||
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
|
||||
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
|
||||
schedule / plugin triggers). Both sets share the same service code, the same
|
||||
response shapes, and the same error codes; the URL is the *only* difference,
|
||||
so the frontend can pick the right prefix based on which run-detail page the
|
||||
user is on.
|
||||
|
||||
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
|
||||
``published_run_inspector_not_implemented`` 404 code is therefore no longer
|
||||
produced.
|
||||
|
||||
URLs follow the design doc and reuse the existing
|
||||
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
|
||||
:mod:`controllers.console.app.workflow_draft_variable`. The
|
||||
``published`` prefix mirrors it shape-for-shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from services.workflow import inspector_events
|
||||
from services.workflow.node_output_inspector_service import (
|
||||
NodeOutputInspectorError,
|
||||
NodeOutputInspectorService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
|
||||
# intervening proxies (nginx, ingress) don't reap the idle connection.
|
||||
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
|
||||
_HEARTBEAT_EVERY_TICKS = 15
|
||||
# Hard ceiling on a single stream — if we never see a terminal workflow
|
||||
# event (engine crashed, redis dropped the message), force-close after this
|
||||
# many ticks (= seconds).
|
||||
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
|
||||
|
||||
|
||||
def _service() -> NodeOutputInspectorService:
|
||||
"""One-line factory so tests can monkeypatch a stub if needed."""
|
||||
return NodeOutputInspectorService()
|
||||
|
||||
|
||||
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
|
||||
"""Resource-body shared by draft + published snapshot endpoints.
|
||||
|
||||
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
|
||||
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
|
||||
behaviour lives here, where unit tests can hit it without spinning up
|
||||
Flask request context.
|
||||
"""
|
||||
try:
|
||||
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return snapshot.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
|
||||
"""Resource-body shared by draft + published node-detail endpoints."""
|
||||
try:
|
||||
view = _service().node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return view.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
|
||||
"""Resource-body shared by draft + published output-preview endpoints."""
|
||||
try:
|
||||
preview = _service().output_preview(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
output_name=output_name,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return preview.model_dump(mode="json")
|
||||
|
||||
|
||||
class _InspectorNotFound(BaseHTTPException):
|
||||
"""404 that preserves the inspector's specific error code.
|
||||
|
||||
Without this the response body collapses to a generic ``not_found`` code
|
||||
and clients lose the ability to distinguish, e.g.,
|
||||
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
|
||||
"""
|
||||
|
||||
code = 404
|
||||
|
||||
def __init__(self, error: NodeOutputInspectorError) -> None:
|
||||
self.error_code = error.code
|
||||
super().__init__(description=str(error))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowDraftRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot organized by producer node."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowDraftRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output (with signed URL for file refs)."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# SSE event stream — shared generator used by draft + published variants
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
|
||||
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
|
||||
|
||||
``data`` is JSON-serialized when given as a dict; raw strings are
|
||||
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
|
||||
"""
|
||||
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
|
||||
"""Yield SSE-framed strings for one workflow run.
|
||||
|
||||
The stream begins with a full ``snapshot`` event so the client has a
|
||||
starting state without needing a separate REST GET. Then for every
|
||||
``node_changed`` message from the pub/sub channel we re-read that node
|
||||
from DB and push a fresh ``node_changed`` event. When the workflow run
|
||||
reaches a terminal state we push one final ``workflow_run_completed``
|
||||
event and close the stream.
|
||||
|
||||
Failures inside the loop are caught and surfaced as ``error`` events so
|
||||
the frontend can show a banner rather than seeing the connection drop
|
||||
silently. The Inspector never raises across the SSE boundary.
|
||||
"""
|
||||
service = _service()
|
||||
run_id_str = str(run_id)
|
||||
|
||||
# Initial snapshot — also flushes a 404 back at the client right away
|
||||
# if the run is gone (raised before yielding any bytes, so Flask turns it
|
||||
# into the normal HTTP 404 path).
|
||||
try:
|
||||
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
|
||||
event_id = 0
|
||||
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
|
||||
|
||||
# If the run already finished by the time the client connected, emit
|
||||
# the terminal envelope synchronously and close — no point subscribing.
|
||||
# The enum value for partial success is the hyphenated ``partial-succeeded``
|
||||
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
|
||||
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Live subscription
|
||||
ticks_since_heartbeat = 0
|
||||
total_ticks = 0
|
||||
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
|
||||
total_ticks += 1
|
||||
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
|
||||
logger.warning(
|
||||
"Inspector SSE: forcing close after %ds without terminal event for run %s",
|
||||
_STREAM_HARD_TIMEOUT_TICKS,
|
||||
run_id_str,
|
||||
)
|
||||
return
|
||||
|
||||
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
|
||||
# ``node_changed`` message with both fields ``None`` on every redis
|
||||
# timeout. Real ``workflow_completed`` messages keep their kind even
|
||||
# when status couldn't be resolved (publisher race), so checking kind
|
||||
# first makes the heartbeat branch safe.
|
||||
if message.kind == "node_changed" and message.node_id is None and message.status is None:
|
||||
ticks_since_heartbeat += 1
|
||||
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
|
||||
yield ":keepalive\n\n"
|
||||
ticks_since_heartbeat = 0
|
||||
continue
|
||||
ticks_since_heartbeat = 0
|
||||
|
||||
if message.kind == "workflow_completed":
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# node_changed: recompute the node slice from DB
|
||||
if not message.node_id:
|
||||
continue
|
||||
try:
|
||||
node_view = service.node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=run_id_str,
|
||||
node_id=message.node_id,
|
||||
)
|
||||
except NodeOutputInspectorError:
|
||||
# Node may not appear in the graph yet (race with persistence); skip.
|
||||
continue
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Inspector SSE: node_detail failed for run %s node %s",
|
||||
run_id_str,
|
||||
message.node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"error",
|
||||
{"node_id": message.node_id, "message": "failed to refresh node detail"},
|
||||
event_id,
|
||||
)
|
||||
continue
|
||||
|
||||
event_id += 1
|
||||
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowDraftRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a draft run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_draft_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Published-run endpoints — symmetric to the draft trio above
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowPublishedRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot for a *published* workflow run.
|
||||
|
||||
Same response shape as the ``/draft/`` variant — frontend can multiplex
|
||||
based on which page (Composer test-run vs. Run History) is mounted.
|
||||
"""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status (published run)."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
|
||||
"/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output of a published run."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output of a published run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a published run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_published_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -1,16 +1,38 @@
|
||||
"""Controller decorators for console app resources.
|
||||
|
||||
`with_session` opens one SQLAlchemy session for a request handler and injects it
|
||||
as the first argument after `self`. Handlers use a transaction by default so
|
||||
migrated write paths keep commit/rollback handling; pure read handlers may opt
|
||||
out with `write=False`. App-loading decorators prefer that injected session when
|
||||
present, while still supporting existing handlers that have not been migrated
|
||||
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import overload
|
||||
from typing import Concatenate, cast, overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
def _load_app_model(session: Session, app_id: str) -> App | None:
|
||||
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
|
||||
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
@ -23,6 +45,63 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R],
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
|
||||
|
||||
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R] | None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> (
|
||||
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
|
||||
):
|
||||
"""Inject a request-scoped session, using a transaction only for write handlers."""
|
||||
|
||||
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if write:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
|
||||
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
candidate = args[1]
|
||||
if isinstance(candidate, Session):
|
||||
return candidate
|
||||
|
||||
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
|
||||
return cast(Session, candidate)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
@ -44,6 +123,13 @@ def get_app_model[**P, R](
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Inject the App model for handlers that receive an `app_id` path parameter.
|
||||
|
||||
New handlers may compose `@with_session` above this decorator so the app row
|
||||
is loaded through the same request-scoped session used by the controller.
|
||||
Existing handlers continue to work through `db.session` until migrated.
|
||||
"""
|
||||
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@ -55,7 +141,11 @@ def get_app_model[**P, R](
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model(app_id)
|
||||
session = _get_injected_session(args)
|
||||
if session is None:
|
||||
app_model = _load_app_model_from_scoped_session(app_id)
|
||||
else:
|
||||
app_model = _load_app_model(session, app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
@ -89,10 +89,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
def delete(self, binding_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -8,9 +8,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
@ -133,12 +133,10 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
|
||||
user_account_id = current_user.id
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@ -48,7 +48,6 @@ class NotionEstimatePayload(BaseModel):
|
||||
class DataSourceNotionListQuery(BaseModel):
|
||||
dataset_id: str | None = Field(default=None, description="Dataset ID")
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
|
||||
|
||||
|
||||
class DataSourceNotionPreviewQuery(BaseModel):
|
||||
@ -205,9 +204,6 @@ class DataSourceNotionListApi(Resource):
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters = query.datasource_parameters or {}
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -255,7 +251,7 @@ class DataSourceNotionListApi(Resource):
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters=datasource_parameters,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -10,7 +10,12 @@ from controllers.common.fields import UsageCountResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
@ -126,9 +131,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_tenant_id
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
def get(self, current_tenant_id: str):
|
||||
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
|
||||
@ -20,6 +20,7 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -40,8 +41,10 @@ register_schema_model(console_ns, TextToAudioPayload)
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
@ -81,8 +84,10 @@ class ChatAudioApi(InstalledAppResource):
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -83,8 +83,10 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -133,8 +135,10 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -157,8 +161,10 @@ class CompletionStopApi(InstalledAppResource):
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -209,8 +215,10 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app, task_id: str):
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -8,6 +8,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -20,7 +21,7 @@ from fields.conversation_fields import (
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -44,8 +45,10 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -92,8 +95,10 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app, c_id: UUID):
|
||||
def delete(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -115,8 +120,10 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app, c_id: UUID):
|
||||
def post(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -146,8 +153,10 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -170,8 +179,10 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app):
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
def patch(self, installed_app: InstalledApp):
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
|
||||
@ -10,6 +10,7 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -21,15 +22,16 @@ from controllers.console.explore.error import (
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -59,9 +61,11 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -96,9 +100,11 @@ class MessageListApi(InstalledAppResource):
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -124,9 +130,11 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -170,9 +178,11 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -7,11 +7,14 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.model import InstalledApp
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -22,9 +25,11 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -46,9 +51,11 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -67,9 +74,11 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
def delete(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
|
||||
|
||||
class CodeBasedExtensionQuery(BaseModel):
|
||||
@ -116,11 +116,11 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@ -130,9 +130,9 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -153,12 +153,12 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@ -169,9 +169,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, id: UUID):
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
@ -197,9 +197,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: UUID):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
|
||||
@ -2,11 +2,11 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from libs.login import current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
@ -24,10 +24,9 @@ class FeatureApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
@ -49,10 +48,9 @@ class FeatureVectorSpaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
||||
@ -22,10 +22,13 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
@ -62,8 +65,8 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
source_str = request.form.get("source")
|
||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||
|
||||
@ -107,10 +110,10 @@ class FilePreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, file_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -8,8 +8,14 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
@ -70,11 +76,10 @@ class NotificationApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
def get(self, current_user: Account):
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
@ -113,11 +118,11 @@ class NotificationDismissApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account):
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@ -12,11 +12,13 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -49,7 +51,8 @@ class RemoteFileUpload(Resource):
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
|
||||
@login_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = payload.url
|
||||
|
||||
@ -74,12 +77,11 @@ class RemoteFileUpload(Resource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
user=current_user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
|
||||
@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
@ -92,8 +99,8 @@ class TagListApi(Resource):
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
@ -109,9 +116,9 @@ class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Allow users with edit permission, or dataset editors (including dataset operators).
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
def _require_tag_binding_edit_permission(current_user: Account) -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _create_tag_bindings(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _remove_tag_bindings(current_user)
|
||||
|
||||
@ -77,7 +77,7 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
|
||||
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
|
||||
|
||||
|
||||
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
@ -113,7 +113,7 @@ def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
|
||||
if new_member_count <= 0:
|
||||
return
|
||||
|
||||
features = FeatureService.get_features(tenant_id=tenant_id)
|
||||
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
workspace_members = features.workspace_members
|
||||
|
||||
@ -166,10 +166,10 @@ class TenantListApi(Resource):
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@ -16,6 +17,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.encryption import FieldEncryption
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
@ -94,21 +96,28 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(current_tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403,
|
||||
"The capacity of the knowledge storage space has reached the limit of your subscription.",
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
|
||||
)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
@ -138,7 +147,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
@ -289,7 +298,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
@ -303,7 +312,6 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object() # type: ignore
|
||||
if not isinstance(user, Account):
|
||||
@ -321,7 +329,6 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object()
|
||||
if not isinstance(user, Account) or not user.is_admin_or_owner:
|
||||
@ -489,3 +496,25 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_tenant_id[T, **P, R](
|
||||
view: Callable[Concatenate[T, str, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return view(self, current_tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
128
api/controllers/openapi/__init__.py
Normal file
128
api/controllers/openapi/__init__.py
Normal file
@ -0,0 +1,128 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
ServerVersionResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspacePayload,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
WorkflowRunData,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
AccountResponse,
|
||||
SessionRow,
|
||||
SessionListResponse,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
FileResponse,
|
||||
ServerVersionResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
_meta,
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
files,
|
||||
human_input_form,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_events,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
# Request models are imported from _models.py and registered above.
|
||||
|
||||
__all__ = [
|
||||
"_meta",
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
"files",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_events",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
66
api/controllers/openapi/_audit.py
Normal file
66
api/controllers/openapi/_audit.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||
|
||||
|
||||
def emit_app_run(
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
caller_kind: str,
|
||||
mode: str,
|
||||
surface: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
surface,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
"surface": surface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_wrong_surface(
|
||||
*,
|
||||
subject_type: str | None,
|
||||
attempted_path: str,
|
||||
client_id: str | None,
|
||||
token_id: str | None,
|
||||
) -> None:
|
||||
logger.warning(
|
||||
"audit: %s subject_type=%s attempted_path=%s",
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
subject_type,
|
||||
attempted_path,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
"subject_type": subject_type,
|
||||
"attempted_path": attempted_path,
|
||||
"client_id": client_id,
|
||||
"token_id": token_id,
|
||||
},
|
||||
)
|
||||
143
api/controllers/openapi/_input_schema.py
Normal file
143
api/controllers/openapi/_input_schema.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out: dict[str, Any] = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
23
api/controllers/openapi/_meta.py
Normal file
23
api/controllers/openapi/_meta.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
|
||||
|
||||
Returns the server's project version and edition so the difyctl CLI can probe
|
||||
compatibility without needing to be logged in. Mirrors the `_health` endpoint
|
||||
in `index.py`.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import ServerVersionResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_version")
|
||||
class VersionApi(Resource):
|
||||
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||
def get(self):
|
||||
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||
return ServerVersionResponse(
|
||||
version=dify_config.project.version,
|
||||
edition=edition,
|
||||
).model_dump(mode="json")
|
||||
344
api/controllers/openapi/_models.py
Normal file
344
api/controllers/openapi/_models.py
Normal file
@ -0,0 +1,344 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class PermittedExternalAppsListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
class AccountPayload(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspacePayload(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
subject_type: str
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account: AccountPayload | None = None
|
||||
workspaces: list[WorkspacePayload] = []
|
||||
default_workspace_id: str | None = None
|
||||
|
||||
|
||||
class SessionRow(BaseModel):
|
||||
id: str
|
||||
prefix: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
created_at: str | None = None
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WorkspaceSummaryResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
workspaces: list[WorkspaceSummaryResponse]
|
||||
|
||||
|
||||
class WorkspaceDetailResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class DeviceCodeResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceLookupResponse(BaseModel):
|
||||
valid: bool
|
||||
expires_in_remaining: int = 0
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class DeviceMutateResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class ServerVersionResponse(BaseModel):
|
||||
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
|
||||
|
||||
version: str
|
||||
edition: Literal["SELF_HOSTED", "CLOUD"]
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
workspace_id: UUIDStrOrEmpty | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class PermittedExternalAppsListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
|
||||
|
||||
|
||||
class ExtSubjectAssertionClaims(BaseModel):
|
||||
email: str = _EMAIL_FIELD
|
||||
issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
class ApprovalGrantClaimsPayload(BaseModel):
|
||||
subject_email: str = _EMAIL_FIELD
|
||||
subject_issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
169
api/controllers/openapi/account.py
Normal file
169
api/controllers/openapi/account.py
Normal file
@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.oauth_device_flow import (
|
||||
list_active_sessions,
|
||||
revoke_oauth_token,
|
||||
token_belongs_to_subject,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email,
|
||||
subject_issuer=ctx.subject_issuer,
|
||||
account=None,
|
||||
workspaces=[],
|
||||
default_workspace_id=None,
|
||||
).model_dump(mode="json")
|
||||
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
|
||||
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type=ctx.subject_type,
|
||||
subject_email=ctx.subject_email or (account.email if account else None),
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = list_active_sessions(db.session, ctx, now)
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
SessionRow(
|
||||
id=str(r.id),
|
||||
prefix=r.prefix,
|
||||
client_id=r.client_id,
|
||||
device_label=r.device_label,
|
||||
created_at=_iso(r.created_at),
|
||||
last_used_at=_iso(r.last_used_at),
|
||||
expires_at=_iso(r.expires_at),
|
||||
)
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
if not token_belongs_to_subject(db.session, session_id, ctx):
|
||||
raise NotFound("session not found")
|
||||
|
||||
revoke_oauth_token(db.session, redis_client, session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
|
||||
if not ctx.source.startswith("oauth"):
|
||||
raise BadRequest(
|
||||
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
|
||||
)
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> WorkspacePayload:
|
||||
join, tenant = row
|
||||
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||
|
||||
|
||||
def _account_payload(account) -> AccountPayload:
|
||||
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||
165
api/controllers/openapi/app_run.py
Normal file
165
api/controllers/openapi/app_run.py
Normal file
@ -0,0 +1,165 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
270
api/controllers/openapi/apps.py
Normal file
270
api/controllers/openapi/apps.py
Normal file
@ -0,0 +1,270 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → publishes the auth ContextVar before `require_scope`
|
||||
reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
accept_subjects(SubjectType.ACCOUNT),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
|
||||
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
def get(self):
|
||||
ctx: AuthContext = get_auth_ctx()
|
||||
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
require_workspace_member(ctx, workspace_id)
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name = None
|
||||
if pagination.items:
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
|
||||
env = AppListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=cast(int, pagination.total),
|
||||
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
102
api/controllers/openapi/apps_permitted_external.py
Normal file
102
api/controllers/openapi/apps_permitted_external.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||
|
||||
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||
(public / sso_verified). License-gated: CE deploys never enable the
|
||||
EE blueprint chain so this module is unreachable there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
from services.openapi.license_gate import license_required
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
|
||||
license_required,
|
||||
accept_subjects(SubjectType.EXTERNAL_SSO),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
def get(self):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
}
|
||||
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
|
||||
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=page_result.total,
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
3
api/controllers/openapi/auth/__init__.py
Normal file
3
api/controllers/openapi/auth/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
46
api/controllers/openapi/auth/composition.py
Normal file
46
api/controllers/openapi/auth/composition.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
SurfaceCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
68
api/controllers/openapi/auth/context.py
Normal file
68
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
|
||||
Context is intentionally decoupled from Flask's ``Request``: the pipeline
|
||||
guard extracts whatever transport-level inputs the steps need (bearer
|
||||
token, path params) at the boundary and writes them into Context fields,
|
||||
so steps stay testable without a request object and won't leak coupling
|
||||
to a specific framework.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
required_scope: Scope
|
||||
bearer_token: str | None = None
|
||||
path_params: Mapping[str, str] = field(default_factory=dict)
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
auth_ctx_reset_token: Token[AuthContext] | None = None
|
||||
|
||||
@property
|
||||
def must_tenant(self) -> Tenant:
|
||||
if not self.tenant:
|
||||
raise Unauthorized("tenant is not associated")
|
||||
return self.tenant
|
||||
|
||||
@property
|
||||
def must_subject_type(self) -> SubjectType:
|
||||
if not self.subject_type:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
return self.subject_type
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
51
api/controllers/openapi/auth/pipeline.py
Normal file
51
api/controllers/openapi/auth/pipeline.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
# Extract transport-level inputs at the boundary so steps
|
||||
# stay decoupled from Flask's request object.
|
||||
ctx = Context(
|
||||
required_scope=scope,
|
||||
bearer_token=extract_bearer(request),
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
try:
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
if ctx.auth_ctx_reset_token is not None:
|
||||
reset_auth_ctx(ctx.auth_ctx_reset_token)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
170
api/controllers/openapi/auth/steps.py
Normal file
170
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`. `BearerCheck` also publishes the
|
||||
resolved identity to the openapi auth ``ContextVar`` (the same one the
|
||||
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
|
||||
surface gate and any handler reading the request-scoped context has a single
|
||||
source of truth across both auth-attach paths. The reset token is stashed
|
||||
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
|
||||
its `finally` so worker-thread reuse can't leak identity across requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from controllers.openapi.auth.surface_gate import check_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from models import TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here.
|
||||
Also publishes the resolved `AuthContext` via
|
||||
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
|
||||
``validate_bearer`` writes — so the surface gate + downstream readers
|
||||
don't see two different identity sources. The reset token is parked on
|
||||
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not ctx.bearer_token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(ctx.bearer_token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class SurfaceCheck:
|
||||
"""Reject the request if the resolved subject is not in `accepted`."""
|
||||
|
||||
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
|
||||
self._accepted = accepted
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
check_surface(self._accepted)
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
|
||||
``ctx.path_params`` at the boundary so this step doesn't need to know
|
||||
about the request object.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = ctx.path_params.get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.must_tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.must_subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AppAuthzCheck",
|
||||
"AppResolver",
|
||||
"AuthContext",
|
||||
"BearerCheck",
|
||||
"CallerMount",
|
||||
"ScopeCheck",
|
||||
"SurfaceCheck",
|
||||
"WorkspaceMembershipCheck",
|
||||
]
|
||||
168
api/controllers/openapi/auth/strategies.py
Normal file
168
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import (
|
||||
EnterpriseService,
|
||||
WebAppAccessMode,
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL, evaluated in two stages.
|
||||
|
||||
The EE gateway has already enforced tenancy and workspace membership
|
||||
by the time this strategy runs, so AclStrategy only owns per-app ACL:
|
||||
|
||||
1. Subject vs access-mode compatibility (pure rule table). External-SSO
|
||||
bearers belong to public-facing apps only; account bearers cover the
|
||||
full set. A mismatch is an immediate deny — no IO.
|
||||
2. For modes that pair with the subject, decide whether the inner
|
||||
permission API must run. Only `PRIVATE` (per-app selected-user list)
|
||||
requires it; the remaining modes are pass-through.
|
||||
"""
|
||||
|
||||
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
|
||||
SubjectType.ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
SubjectType.EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
access_mode = self._fetch_access_mode(ctx.app.id)
|
||||
if access_mode is None:
|
||||
return False
|
||||
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
|
||||
return False
|
||||
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
|
||||
return True
|
||||
return self._inner_permission_check(ctx)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if settings is None:
|
||||
return None
|
||||
try:
|
||||
return WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
|
||||
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
|
||||
|
||||
def _inner_permission_check(self, ctx: Context) -> bool:
|
||||
if ctx.app is None:
|
||||
return False
|
||||
user_id = self._resolve_user_id(ctx)
|
||||
if user_id is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_id(ctx: Context) -> str | None:
|
||||
if ctx.subject_type == SubjectType.ACCOUNT:
|
||||
return str(ctx.account_id) if ctx.account_id is not None else None
|
||||
if ctx.subject_email is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
|
||||
return str(account.id) if account is not None else None
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user) # type:ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.must_tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
89
api/controllers/openapi/auth/surface_gate.py
Normal file
89
api/controllers/openapi/auth/surface_gate.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that the resolved subject is in ``accepted``.
|
||||
|
||||
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
|
||||
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
|
||||
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
|
||||
set the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
72
api/controllers/openapi/files.py
Normal file
72
api/controllers/openapi/files.py
Normal file
@ -0,0 +1,72 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from models import Account, App
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/files/upload")
|
||||
class AppFileUploadApi(Resource):
|
||||
@openapi_ns.doc("upload_file_for_app_input")
|
||||
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
|
||||
@openapi_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request — no file or filename missing",
|
||||
401: "Unauthorized — invalid or expired bearer token",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type or blocked extension",
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=caller,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except services.errors.file.FileTooLargeError as exc:
|
||||
raise FileTooLargeError(exc.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError(exc.description)
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
107
api/controllers/openapi/human_input_form.py
Normal file
107
api/controllers/openapi/human_input_form.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""
|
||||
OpenAPI bearer-authed human input form endpoints.
|
||||
|
||||
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
|
||||
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App
|
||||
from services.human_input_service import FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
if caller_kind == "account":
|
||||
submission_user_id = caller.id
|
||||
else:
|
||||
submission_end_user_id = caller.id
|
||||
|
||||
if form.recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_token=%s", form_token)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
9
api/controllers/openapi/index.py
Normal file
9
api/controllers/openapi/index.py
Normal file
@ -0,0 +1,9 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
398
api/controllers/openapi/oauth_device.py
Normal file
398
api/controllers/openapi/oauth_device.py
Normal file
@ -0,0 +1,398 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Validation helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
365
api/controllers/openapi/oauth_device_sso.py
Normal file
365
api/controllers/openapi/oauth_device_sso.py
Normal file
@ -0,0 +1,365 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import bp
|
||||
from controllers.openapi._models import ExtSubjectAssertionClaims
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
def _trusted_origin() -> str:
|
||||
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
|
||||
if not base:
|
||||
raise BadGateway("console_api_url_unset")
|
||||
return base
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
origin = _trusted_origin()
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
try:
|
||||
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
|
||||
except ValidationError as e:
|
||||
logger.warning("sso-complete: claim shape invalid: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = claims.user_code.strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.email):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = _trusted_origin()
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims.email,
|
||||
subject_issuer=claims.issuer,
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
# SSO branch of the shared PollPayload contract: account/workspace
|
||||
# fields are zero-filled (`None` / `[]`) for parity with the account
|
||||
# branch in `oauth_device._build_account_poll_payload`.
|
||||
poll_payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RejectedClaims:
|
||||
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||
|
||||
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||
event without reaching for the full dataclass.
|
||||
"""
|
||||
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
reason,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_rejected",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"reason": reason,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
119
api/controllers/openapi/workflow_events.py
Normal file
119
api/controllers/openapi/workflow_events.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
OpenAPI bearer-authed workflow reconnect event stream endpoint.
|
||||
|
||||
GET /apps/<app_id>/tasks/<task_id>/events
|
||||
— reconnect to the SSE stream for a paused/running workflow run.
|
||||
`task_id` is treated as `workflow_run_id`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if caller_kind == "account":
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
else:
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=caller,
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
else:
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.OPENAPI,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
78
api/controllers/openapi/workspaces.py
Normal file
78
api/controllers/openapi/workspaces.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
get_auth_ctx,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
)
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||
return WorkspaceDetailResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
)
|
||||
@ -13,6 +13,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -140,20 +141,26 @@ def cloud_edition_billing_resource_check[**P, R](
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(api_token.tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
@ -174,7 +181,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
@ -56,7 +56,7 @@ class AppParameterApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app_model.workflow
|
||||
@ -92,7 +92,7 @@ class AppMeta(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model: App, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Get app meta"""
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ from core.errors.error import (
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -86,7 +86,7 @@ class CompletionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -169,7 +169,7 @@ class ChatApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, app_model, end_user, task_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.conversation_fields import (
|
||||
SimpleConversation,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -81,7 +81,7 @@ class ConversationListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -127,7 +127,7 @@ class ConversationApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, c_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -166,7 +166,7 @@ class ConversationRenameApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model, end_user, c_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -204,7 +204,7 @@ class ConversationPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -235,7 +235,7 @@ class ConversationUnPinApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
|
||||
def patch(self, app_model, end_user, c_id: UUID):
|
||||
def patch(self, app_model: App, end_user: EndUser, c_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
register_schema_models(web_ns, FileResponse)
|
||||
@ -31,7 +32,7 @@ class FileApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(201, "File uploaded successfully", web_ns.models[FileResponse.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Upload a file for use in web applications.
|
||||
|
||||
Accepts file uploads for use within web applications, supporting
|
||||
|
||||
@ -27,7 +27,7 @@ from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfinite
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -81,7 +81,7 @@ class MessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -133,7 +133,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__])
|
||||
def post(self, app_model, end_user, message_id: UUID):
|
||||
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
|
||||
@ -167,7 +167,7 @@ class MessageMoreLikeThisApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -223,7 +223,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user, message_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -13,6 +13,7 @@ from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from models.model import App, EndUser
|
||||
from services.file_service import FileService
|
||||
|
||||
from ..common.schema import register_response_schema_models, register_schema_models
|
||||
@ -41,7 +42,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Remote file info", web_ns.models[RemoteFileInfo.__name__])
|
||||
def get(self, app_model, end_user, url):
|
||||
def get(self, app_model: App, end_user: EndUser, url: str):
|
||||
"""Get information about a remote file.
|
||||
|
||||
Retrieves basic information about a file located at a remote URL,
|
||||
@ -85,7 +86,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(201, "Remote file uploaded", web_ns.models[FileWithSignedUrl.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Upload a file from a remote URL.
|
||||
|
||||
Downloads a file from the provided remote URL and uploads it
|
||||
|
||||
@ -11,6 +11,7 @@ from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models.model import App, EndUser
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -43,7 +44,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -77,7 +78,7 @@ class SavedMessageListApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@web_ns.response(200, "Message saved successfully", web_ns.models[ResultResponse.__name__])
|
||||
def post(self, app_model, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -106,7 +107,7 @@ class SavedMessageApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def delete(self, app_model, end_user, message_id: UUID):
|
||||
def delete(self, app_model: App, end_user: EndUser, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
if app_model.mode != "completion":
|
||||
|
||||
@ -10,7 +10,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.model import App, EndUser, Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class AppSiteApi(WebApiResource):
|
||||
}
|
||||
)
|
||||
@marshal_with(app_fields)
|
||||
def get(self, app_model, end_user):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
"""Retrieve app site info."""
|
||||
# get site
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -78,10 +78,10 @@ class AppSiteApi(WebApiResource):
|
||||
if not site:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo
|
||||
|
||||
return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo)
|
||||
|
||||
@ -119,6 +119,6 @@ def serialize_site(site: Site) -> dict[str, Any]:
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id, exclude_vector_space=True).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@ -16,7 +16,7 @@ from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, EndUser, Site
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
@ -74,7 +74,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
app_web_auth_enabled = webapp_settings.access_mode != WebAppAccessMode.PUBLIC
|
||||
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
|
||||
_validate_user_accessibility(
|
||||
@ -88,7 +88,8 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
|
||||
raise Unauthorized("Please re-login to access the web app.")
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode
|
||||
!= WebAppAccessMode.PUBLIC
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
@ -22,9 +22,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from extensions.ext_database import db
|
||||
@ -150,44 +147,9 @@ class BaseAgentRunner(AppRunner):
|
||||
message_tool = PromptMessageTool(
|
||||
name=tool.tool_name,
|
||||
description=tool_entity.entity.description.llm,
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
parameters=tool_entity.get_llm_parameters_json_schema(),
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool, tool_entity
|
||||
|
||||
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
|
||||
@ -252,40 +214,7 @@ class BaseAgentRunner(AppRunner):
|
||||
"""
|
||||
update prompt message tool
|
||||
"""
|
||||
# try to get tool runtime parameters
|
||||
tool_runtime_parameters = tool.get_runtime_parameters()
|
||||
|
||||
for parameter in tool_runtime_parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
continue
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters["required"]:
|
||||
prompt_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
prompt_tool.parameters = tool.get_llm_parameters_json_schema()
|
||||
return prompt_tool
|
||||
|
||||
def create_agent_thought(
|
||||
|
||||
@ -198,7 +198,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=(
|
||||
args.get("parent_message_id")
|
||||
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||
else UUID_NIL
|
||||
),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@ -167,7 +167,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=(
|
||||
args.get("parent_message_id")
|
||||
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||
else UUID_NIL
|
||||
),
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
EasyUIBasedAppGenerateEntity,
|
||||
@ -292,46 +293,51 @@ class AppRunner:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
text = ""
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
try:
|
||||
for result in invoke_result:
|
||||
if not agent:
|
||||
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
message = result.delta.message
|
||||
if isinstance(message.content, str):
|
||||
text += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for content in message.content:
|
||||
if isinstance(content, str):
|
||||
text += content
|
||||
elif isinstance(content, TextPromptMessageContent):
|
||||
text += content.data
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
if message_id and user_id and tenant_id:
|
||||
try:
|
||||
self._handle_multimodal_image_content(
|
||||
content=content,
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
queue_manager=queue_manager,
|
||||
)
|
||||
except Exception:
|
||||
_logger.exception("Failed to handle multimodal image output")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
_logger.warning("Received multimodal output but missing required parameters")
|
||||
else:
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
text += content.data if hasattr(content, "data") else str(content)
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
if not model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if not prompt_messages:
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
if result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
except GenerateTaskStoppedError:
|
||||
# Explicitly close provider stream to stop in-flight token generation ASAP.
|
||||
invoke_result.close()
|
||||
raise
|
||||
|
||||
if usage is None:
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
@ -161,7 +161,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=(
|
||||
args.get("parent_message_id")
|
||||
if invoke_from not in {InvokeFrom.SERVICE_API, InvokeFrom.OPENAPI}
|
||||
else UUID_NIL
|
||||
),
|
||||
user_id=user.id,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
|
||||
@ -53,6 +53,14 @@ from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
|
||||
# Maps the entry surface a workflow was invoked from to the HITL surface that
|
||||
# its resume tokens must be filtered for. Surfaces not in this map fall back to
|
||||
# the general priority ordering (typically CONSOLE > BACKSTAGE).
|
||||
_INVOKE_FROM_TO_HITL_SURFACE: Mapping[InvokeFrom, HumanInputSurface] = {
|
||||
InvokeFrom.SERVICE_API: HumanInputSurface.SERVICE_API,
|
||||
InvokeFrom.OPENAPI: HumanInputSurface.OPENAPI,
|
||||
}
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -340,11 +348,7 @@ class WorkflowResponseConverter:
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=(
|
||||
HumanInputSurface.SERVICE_API
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
||||
else None
|
||||
),
|
||||
surface=_INVOKE_FROM_TO_HITL_SURFACE.get(self._application_generate_entity.invoke_from),
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
|
||||
@ -731,6 +731,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.OPENAPI:
|
||||
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
|
||||
@ -24,6 +24,7 @@ class UserFrom(StrEnum):
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
OPENAPI = "openapi"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
@ -42,6 +43,7 @@ class InvokeFrom(StrEnum):
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
InvokeFrom.OPENAPI: "openapi",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
@ -47,6 +47,12 @@ from graphon.graph_events import (
|
||||
)
|
||||
from graphon.node_events import NodeRunResult
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from services.workflow.inspector_events import (
|
||||
publish_node_changed as _inspector_publish_node_changed,
|
||||
)
|
||||
from services.workflow.inspector_events import (
|
||||
publish_workflow_completed as _inspector_publish_workflow_completed,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -163,6 +169,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -173,6 +180,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -184,6 +192,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._fail_running_node_executions(error_message=event.error)
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -194,6 +203,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._fail_running_node_executions(error_message=execution.error_message or "")
|
||||
self._workflow_execution_repository.save(execution)
|
||||
self._enqueue_trace_task(execution)
|
||||
_inspector_publish_workflow_completed(workflow_run_id=execution.id_, status=str(execution.status.value))
|
||||
|
||||
def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
|
||||
execution = self._get_workflow_execution()
|
||||
@ -241,6 +251,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
_inspector_publish_node_changed(workflow_run_id=execution.id_, node_id=event.node_id, status="running")
|
||||
|
||||
def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -248,6 +259,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
domain_execution.error = event.error
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="retry",
|
||||
)
|
||||
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -257,6 +273,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="succeeded",
|
||||
)
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -267,6 +288,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@ -277,6 +303,11 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
error=event.error,
|
||||
finished_at=event.finished_at,
|
||||
)
|
||||
_inspector_publish_node_changed(
|
||||
workflow_run_id=self._get_workflow_execution().id_,
|
||||
node_id=domain_execution.node_id,
|
||||
status="exception",
|
||||
)
|
||||
|
||||
def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
|
||||
@ -3,6 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any, cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
@ -53,6 +54,9 @@ else:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLUGIN_DAEMON_MAX_PATH_LENGTH = 4096
|
||||
PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH = 8
|
||||
|
||||
_httpx_client: httpx.Client = get_pooled_http_client(
|
||||
"plugin_daemon",
|
||||
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
|
||||
@ -103,6 +107,20 @@ class BasePluginClient:
|
||||
params: dict[str, Any] | None,
|
||||
files: dict[str, Any] | None,
|
||||
) -> tuple[str, dict[str, str], bytes | dict[str, Any] | str | None, dict[str, Any] | None, dict[str, Any] | None]:
|
||||
if len(path) > PLUGIN_DAEMON_MAX_PATH_LENGTH:
|
||||
raise ValueError(f"Invalid plugin daemon path: path length exceeds {PLUGIN_DAEMON_MAX_PATH_LENGTH}")
|
||||
|
||||
decoded_path = path
|
||||
for _ in range(PLUGIN_DAEMON_MAX_PATH_DECODE_DEPTH):
|
||||
next_decoded_path = unquote(decoded_path)
|
||||
if next_decoded_path == decoded_path:
|
||||
break
|
||||
decoded_path = next_decoded_path
|
||||
else:
|
||||
raise ValueError("Invalid plugin daemon path: path is too deeply encoded")
|
||||
|
||||
if any(seg == ".." for seg in decoded_path.split("/")):
|
||||
raise ValueError(f"Invalid plugin daemon path: traversal sequence detected in {path!r}")
|
||||
url = plugin_daemon_inner_api_baseurl / path
|
||||
prepared_headers = dict(headers or {})
|
||||
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
|
||||
|
||||
@ -534,7 +534,9 @@ class ProviderManager:
|
||||
cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
|
||||
cache_result = redis_client.get(cache_key)
|
||||
if cache_result is None:
|
||||
model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
|
||||
model_load_balancing_enabled = FeatureService.get_features(
|
||||
tenant_id, exclude_vector_space=True
|
||||
).model_load_balancing_enabled
|
||||
redis_client.setex(cache_key, 120, str(model_load_balancing_enabled))
|
||||
else:
|
||||
cache_result = cache_result.decode("utf-8")
|
||||
|
||||
@ -126,34 +126,89 @@ class Tool(ABC):
|
||||
message_id: str | None = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get merged runtime parameters
|
||||
Get the effective parameter declarations for this tool.
|
||||
|
||||
Runtime parameters override declared parameters by name and append new
|
||||
parameters, but the returned list is always detached from the tool's
|
||||
cached declarations so callers can safely mutate it while building
|
||||
downstream schemas.
|
||||
|
||||
:return: merged runtime parameters
|
||||
"""
|
||||
parameters = self.entity.parameters
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
parameters = [deepcopy(parameter) for parameter in self.entity.parameters or []]
|
||||
user_parameters = [
|
||||
deepcopy(parameter)
|
||||
for parameter in self.get_runtime_parameters(
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
or []
|
||||
]
|
||||
|
||||
parameter_indexes = {parameter.name: index for index, parameter in enumerate(parameters)}
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
break
|
||||
else:
|
||||
# add new parameter
|
||||
existing_index = parameter_indexes.get(parameter.name)
|
||||
if existing_index is None:
|
||||
parameter_indexes[parameter.name] = len(parameters)
|
||||
parameters.append(parameter)
|
||||
continue
|
||||
parameters[existing_index] = parameter
|
||||
|
||||
return parameters
|
||||
|
||||
def get_llm_parameters_json_schema(
|
||||
self,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build the model-visible JSON schema from effective tool parameters.
|
||||
|
||||
Hidden/manual parameters stay available for invocation preparation on the
|
||||
API side, but are intentionally omitted from the LLM-facing schema.
|
||||
"""
|
||||
schema: dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
for parameter in self.get_merged_runtime_parameters(
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
):
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
continue
|
||||
|
||||
parameter_schema: dict[str, Any] = (
|
||||
{
|
||||
"type": parameter.type.as_normal_type(),
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else deepcopy(parameter.input_schema)
|
||||
)
|
||||
parameter_schema.setdefault("description", parameter.llm_description or "")
|
||||
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT and parameter.options:
|
||||
parameter_schema["enum"] = [option.value for option in parameter.options]
|
||||
|
||||
schema["properties"][parameter.name] = parameter_schema
|
||||
if parameter.required:
|
||||
schema["required"].append(parameter.name)
|
||||
|
||||
return schema
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
|
||||
@ -63,7 +63,7 @@ def _get_surface_form_token(
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
if surface in {HumanInputSurface.SERVICE_API, HumanInputSurface.OPENAPI}:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
@ -11,13 +11,15 @@ from models.human_input import RecipientType
|
||||
class HumanInputSurface(StrEnum):
|
||||
SERVICE_API = "service_api"
|
||||
CONSOLE = "console"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
|
||||
# Service API is intentionally narrower than other surfaces: app-token callers
|
||||
# SERVICE_API and OPENAPI are intentionally narrower than CONSOLE: token callers
|
||||
# should only be able to act on end-user web forms, not internal console flows.
|
||||
_ALLOWED_RECIPIENT_TYPES_BY_SURFACE: dict[HumanInputSurface, frozenset[RecipientType]] = {
|
||||
HumanInputSurface.SERVICE_API: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
HumanInputSurface.CONSOLE: frozenset({RecipientType.CONSOLE, RecipientType.BACKSTAGE}),
|
||||
HumanInputSurface.OPENAPI: frozenset({RecipientType.STANDALONE_WEB_APP}),
|
||||
}
|
||||
|
||||
# A single HITL form can have multiple recipient records; this shared priority
|
||||
|
||||
@ -473,11 +473,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
from clients.agent_backend import AgentBackendRunEventAdapter, AgentBackendRunRequestBuilder
|
||||
from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from core.workflow.nodes.agent_v2.file_tenant_validator import UploadFileTenantValidator
|
||||
from core.workflow.nodes.agent_v2.output_check_executor import FileOutputCheckExecutor
|
||||
from core.workflow.nodes.agent_v2.output_check_model_invoker import ModelRuntimeOutputCheckInvoker
|
||||
from core.workflow.nodes.agent_v2.output_failure_orchestrator import OutputFailureOrchestrator
|
||||
from core.workflow.nodes.agent_v2.output_type_checker import PerOutputTypeChecker
|
||||
from core.workflow.nodes.agent_v2.upload_file_content_loader import UploadFileContentLoader
|
||||
|
||||
return {
|
||||
"binding_resolver": WorkflowAgentBindingResolver(),
|
||||
@ -492,15 +489,10 @@ class DifyNodeFactory(NodeFactory):
|
||||
),
|
||||
"event_adapter": AgentBackendRunEventAdapter(),
|
||||
"output_adapter": WorkflowAgentOutputAdapter(),
|
||||
# Stage 4 §5/§6/§7: per-output validation + benchmark check +
|
||||
# failure orchestration. The tenant validator and content
|
||||
# loader query upload_files lazily so they stay cheap when
|
||||
# declared outputs include no file refs.
|
||||
# Stage 4 §5/§7: per-output validation + failure orchestration. The
|
||||
# tenant validator queries upload_files so it stays cheap when
|
||||
# outputs contain no file refs.
|
||||
"type_checker": PerOutputTypeChecker(file_validator=UploadFileTenantValidator()),
|
||||
"output_check_executor": FileOutputCheckExecutor(
|
||||
content_loader=UploadFileContentLoader(),
|
||||
model_invoker=ModelRuntimeOutputCheckInvoker(),
|
||||
),
|
||||
"failure_orchestrator": OutputFailureOrchestrator(),
|
||||
}
|
||||
return {
|
||||
|
||||
@ -23,12 +23,11 @@ from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from models.agent_config_entities import AgentSoulConfig, AgentSoulModelConfig, WorkflowNodeJobConfig
|
||||
from models.agent_config_entities import WorkflowNodeJobConfig
|
||||
|
||||
from .binding_resolver import WorkflowAgentBindingError, WorkflowAgentBindingResolver
|
||||
from .entities import DifyAgentNodeData
|
||||
from .output_adapter import WorkflowAgentOutputAdapter
|
||||
from .output_check_executor import FileOutputCheckExecutor, FileOutputCheckOutcome
|
||||
from .output_failure_orchestrator import (
|
||||
FailedOutput,
|
||||
OutputFailureDecision,
|
||||
@ -74,7 +73,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
event_adapter: AgentBackendRunEventAdapter,
|
||||
output_adapter: WorkflowAgentOutputAdapter,
|
||||
type_checker: PerOutputTypeChecker,
|
||||
output_check_executor: FileOutputCheckExecutor,
|
||||
failure_orchestrator: OutputFailureOrchestrator,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@ -89,7 +87,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
self._event_adapter = event_adapter
|
||||
self._output_adapter = output_adapter
|
||||
self._type_checker = type_checker
|
||||
self._output_check_executor = output_check_executor
|
||||
self._failure_orchestrator = failure_orchestrator
|
||||
|
||||
@classmethod
|
||||
@ -146,22 +143,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
outputs_by_name = {o.name: o for o in effective_outputs}
|
||||
|
||||
# Stage 4 §6: output check borrows the Agent Soul's model identity for
|
||||
# its evaluator call. ``runtime_request_builder.build`` would also
|
||||
# reject a missing model later, but we surface a deterministic error
|
||||
# here so the failure_event has a sensible code.
|
||||
agent_soul = AgentSoulConfig.model_validate(bundle.snapshot.config_snapshot_dict)
|
||||
if agent_soul.model is None:
|
||||
yield self._failure_event(
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
metadata=metadata,
|
||||
error="Workflow Agent node requires Agent Soul model config.",
|
||||
error_type="agent_model_not_configured",
|
||||
)
|
||||
return
|
||||
agent_model: AgentSoulModelConfig = agent_soul.model
|
||||
|
||||
# ──── Retry loop (Stage 4 §7) ────
|
||||
attempt = 0
|
||||
while True:
|
||||
@ -253,7 +234,7 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# ──── Stage 4 §5: per-output type check ────
|
||||
# ──── Stage 4: per-output type check ────
|
||||
type_check = self._type_checker.check(
|
||||
declared_outputs=effective_outputs,
|
||||
raw_output=terminal_event.output,
|
||||
@ -261,38 +242,19 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
self._record_type_check_metadata(metadata, type_check)
|
||||
|
||||
# ──── Stage 4 §6: file benchmark output check ────
|
||||
# Only run when type check passes; comparing content of a value
|
||||
# that's already mis-typed is wasted work and would produce a
|
||||
# confusing second failure for the same root cause.
|
||||
output_check: FileOutputCheckOutcome | None = None
|
||||
if not type_check.has_failures:
|
||||
output_check = self._output_check_executor.check_all(
|
||||
declared_outputs=effective_outputs,
|
||||
raw_output=terminal_event.output,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
model_provider=agent_model.model_provider,
|
||||
model_name=agent_model.model,
|
||||
model_settings=agent_model.model_settings,
|
||||
)
|
||||
self._record_output_check_metadata(metadata, output_check)
|
||||
|
||||
if not output_check.has_failures:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=self._output_adapter.build_success_result(
|
||||
event=terminal_event,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=self._output_adapter.build_success_result(
|
||||
event=terminal_event,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
metadata=metadata,
|
||||
)
|
||||
return
|
||||
)
|
||||
return
|
||||
|
||||
# ──── Stage 4 §7: orchestrate retry / default / fail ────
|
||||
# Aggregate failures from both stages. They are mutually exclusive
|
||||
# (§6 only runs when §5 passes), so the resulting list contains
|
||||
# exclusively TYPE_CHECK or OUTPUT_CHECK failures, never both.
|
||||
failures: list[FailedOutput] = [
|
||||
# ──── Stage 4: orchestrate retry / default / fail ────
|
||||
failures = [
|
||||
FailedOutput(
|
||||
declared=outputs_by_name[result.name],
|
||||
failure_kind=OutputFailureKind.TYPE_CHECK,
|
||||
@ -301,17 +263,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
for result in type_check.failures
|
||||
if result.name in outputs_by_name
|
||||
]
|
||||
if output_check is not None:
|
||||
failures.extend(
|
||||
FailedOutput(
|
||||
declared=outputs_by_name[result.output_name],
|
||||
failure_kind=OutputFailureKind.OUTPUT_CHECK,
|
||||
reason=result.reason,
|
||||
)
|
||||
for result in output_check.failures
|
||||
if result.output_name in outputs_by_name
|
||||
)
|
||||
|
||||
outcome = self._failure_orchestrator.decide(failures=failures, current_attempt=attempt)
|
||||
metadata["output_failure_decision"] = outcome.decision.value
|
||||
metadata["output_failure_reason"] = outcome.primary_reason
|
||||
@ -332,16 +283,10 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# Pick an error_type that reflects which stage produced the
|
||||
# surviving failure(s); FAIL_BRANCH gets a suffixed variant so
|
||||
# downstream metrics can tell the two paths apart.
|
||||
base_code = (
|
||||
"output_content_check_failed"
|
||||
if OutputFailureKind.OUTPUT_CHECK in outcome.failure_kinds
|
||||
else "output_type_check_failed"
|
||||
)
|
||||
error_type = (
|
||||
f"{base_code}_fail_branch" if outcome.decision == OutputFailureDecision.TAKE_FAIL_BRANCH else base_code
|
||||
"output_type_check_failed_fail_branch"
|
||||
if outcome.decision == OutputFailureDecision.TAKE_FAIL_BRANCH
|
||||
else "output_type_check_failed"
|
||||
)
|
||||
yield self._failure_event(
|
||||
inputs=inputs,
|
||||
@ -439,45 +384,6 @@ class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _record_output_check_metadata(metadata: dict[str, Any], outcome: FileOutputCheckOutcome) -> None:
|
||||
"""Persist §6 results into node metadata, including the D-2 usage bucket.
|
||||
|
||||
``output_check_usage`` is keyed by output name so multiple file
|
||||
outputs can share metadata without colliding. The bucket is recorded
|
||||
even on FAILED / SKIPPED results so the billing pipeline observes any
|
||||
model usage the executor did spend.
|
||||
"""
|
||||
if not outcome.results:
|
||||
return
|
||||
per_output_usage: dict[str, dict[str, Any]] = {}
|
||||
result_payload: list[dict[str, Any]] = []
|
||||
for r in outcome.results:
|
||||
result_payload.append(
|
||||
{
|
||||
"name": r.output_name,
|
||||
"status": r.status.value,
|
||||
"reason": r.reason,
|
||||
"skip_reason": r.skip_reason.value if r.skip_reason else None,
|
||||
"content_truncated": r.content_truncated,
|
||||
}
|
||||
)
|
||||
per_output_usage[r.output_name] = {
|
||||
"prompt_tokens": r.usage.prompt_tokens,
|
||||
"completion_tokens": r.usage.completion_tokens,
|
||||
"total_tokens": r.usage.total_tokens,
|
||||
"total_price": str(r.usage.total_price),
|
||||
"currency": r.usage.currency,
|
||||
"latency_ms": r.usage.latency_ms,
|
||||
}
|
||||
metadata["output_check"] = {
|
||||
"passed": not outcome.has_failures,
|
||||
"results": result_payload,
|
||||
}
|
||||
# D-2: keep this bucket separate from ``agent_run_usage`` so billing
|
||||
# can tell agent inference apart from output verification.
|
||||
metadata["output_check_usage"] = per_output_usage
|
||||
|
||||
@staticmethod
|
||||
def _patch_event_with_defaults(
|
||||
event: AgentBackendRunSucceededInternalEvent,
|
||||
|
||||
@ -1,455 +0,0 @@
|
||||
"""Per-output file benchmark check executor for Workflow Agent Node v2.
|
||||
|
||||
Stage 4 §6: after :class:`PerOutputTypeChecker` has confirmed that every
|
||||
declared output is structurally well-formed, this executor runs an *optional*,
|
||||
model-based semantic check on file outputs whose
|
||||
``DeclaredOutputCheckConfig.enabled`` is ``True``. The check is performed:
|
||||
|
||||
* by **directly invoking the configured model** (NOT through the Agent backend),
|
||||
because the backend's ``dify.output`` layer only enforces structural JSON
|
||||
schema and has no notion of "compare two file payloads";
|
||||
* with **its token usage bucketed separately** as ``output_check_usage`` so
|
||||
billing / observability never confuses agent-run usage with output-validation
|
||||
usage (decision D-2);
|
||||
* with **file content loading and model invocation pluggable** via
|
||||
:class:`FileContentLoader` and :class:`OutputCheckModelInvoker` Protocols so
|
||||
unit tests can drive the executor without DB / network access.
|
||||
|
||||
Failures here surface upward as
|
||||
``OutputFailureKind.OUTPUT_CHECK`` and feed the existing
|
||||
:class:`OutputFailureOrchestrator` decision chain alongside type-check
|
||||
failures.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from models.agent_config_entities import DeclaredOutputConfig, DeclaredOutputType
|
||||
|
||||
|
||||
class OutputCheckModelInvocationError(Exception):
|
||||
"""Raised by :class:`OutputCheckModelInvoker` when the LLM call fails.
|
||||
|
||||
The executor catches this and produces a ``SKIPPED`` result tagged with
|
||||
:attr:`FileOutputCheckSkipReason.MODEL_INVOCATION_ERROR` so the surrounding
|
||||
retry / fail-branch logic can still proceed deterministically.
|
||||
"""
|
||||
|
||||
|
||||
class FileOutputCheckStatus(StrEnum):
|
||||
"""Lifecycle status of a single file output after the benchmark check."""
|
||||
|
||||
PASSED = "passed"
|
||||
FAILED = "failed"
|
||||
# Check did not produce a pass/fail signal (unsupported file type, file
|
||||
# not accessible, model error, ...). Skipped checks do NOT feed the
|
||||
# failure orchestrator — they are surfaced as warnings in metadata.
|
||||
SKIPPED = "skipped"
|
||||
|
||||
|
||||
class FileOutputCheckSkipReason(StrEnum):
|
||||
"""Why an output-check result is :attr:`FileOutputCheckStatus.SKIPPED`."""
|
||||
|
||||
UNSUPPORTED_FILE_FOR_OUTPUT_CHECK = "unsupported_file_for_output_check"
|
||||
BENCHMARK_FILE_NOT_ACCESSIBLE = "benchmark_file_not_accessible"
|
||||
PRODUCED_FILE_MISSING = "produced_file_missing"
|
||||
MODEL_INVOCATION_ERROR = "output_check_model_error"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FileOutputCheckUsage:
|
||||
"""Token / cost accounting for one output-check LLM invocation (§6.2 D-2).
|
||||
|
||||
Shape intentionally mirrors ``LLMUsage`` so future code can aggregate
|
||||
multiple per-output usages and serialize them next to (but separate from)
|
||||
agent run usage. Zero-valued instance means "no model call was made".
|
||||
"""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
total_price: Decimal = field(default_factory=lambda: Decimal(0))
|
||||
currency: str = "USD"
|
||||
latency_ms: int = 0
|
||||
|
||||
def __add__(self, other: FileOutputCheckUsage) -> FileOutputCheckUsage:
|
||||
if not isinstance(other, FileOutputCheckUsage):
|
||||
return NotImplemented
|
||||
return FileOutputCheckUsage(
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
total_tokens=self.total_tokens + other.total_tokens,
|
||||
total_price=self.total_price + other.total_price,
|
||||
currency=self.currency if self.total_price else other.currency,
|
||||
latency_ms=self.latency_ms + other.latency_ms,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FileOutputCheckResult:
|
||||
"""Outcome of running benchmark check on one declared file output."""
|
||||
|
||||
output_name: str
|
||||
status: FileOutputCheckStatus
|
||||
reason: str
|
||||
usage: FileOutputCheckUsage = field(default_factory=FileOutputCheckUsage)
|
||||
skip_reason: FileOutputCheckSkipReason | None = None
|
||||
content_truncated: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class FileOutputCheckOutcome:
|
||||
"""Aggregate of per-output check results for one agent backend run."""
|
||||
|
||||
results: tuple[FileOutputCheckResult, ...]
|
||||
|
||||
@property
|
||||
def failures(self) -> tuple[FileOutputCheckResult, ...]:
|
||||
return tuple(r for r in self.results if r.status == FileOutputCheckStatus.FAILED)
|
||||
|
||||
@property
|
||||
def has_failures(self) -> bool:
|
||||
return bool(self.failures)
|
||||
|
||||
@property
|
||||
def total_usage(self) -> FileOutputCheckUsage:
|
||||
total = FileOutputCheckUsage()
|
||||
for r in self.results:
|
||||
total = total + r.usage
|
||||
return total
|
||||
|
||||
def by_name(self) -> dict[str, FileOutputCheckResult]:
|
||||
return {r.output_name: r for r in self.results}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LoadedFileContent:
|
||||
"""Output of :class:`FileContentLoader`.
|
||||
|
||||
``text`` is empty when ``unsupported`` is ``True``; callers must check the
|
||||
flag before reading text. ``truncated`` indicates the loader had to drop
|
||||
content to fit the configured budget.
|
||||
"""
|
||||
|
||||
text: str
|
||||
truncated: bool = False
|
||||
unsupported: bool = False
|
||||
|
||||
|
||||
class FileContentLoader(Protocol):
|
||||
"""Resolve a ``file_id`` for the given tenant into model-readable text.
|
||||
|
||||
Returning ``None`` signals the file is missing / cross-tenant / failed to
|
||||
extract — the executor maps that to a SKIPPED result instead of failing
|
||||
the whole node. Implementations must not raise on those cases; raising is
|
||||
reserved for unexpected runtime errors.
|
||||
"""
|
||||
|
||||
def load(self, *, file_id: str, tenant_id: str) -> LoadedFileContent | None: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class OutputCheckModelResponse:
|
||||
"""LLM response wrapping the raw assistant text plus token usage."""
|
||||
|
||||
text: str
|
||||
usage: FileOutputCheckUsage
|
||||
|
||||
|
||||
class OutputCheckModelInvoker(Protocol):
|
||||
"""Direct (non-streaming) LLM invocation for output check.
|
||||
|
||||
The contract is intentionally narrow: one prompt in, one assistant message
|
||||
out. The Agent Soul's model identity is passed explicitly so the executor
|
||||
stays agnostic of how callers resolve the agent's model config.
|
||||
"""
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
model_settings: Mapping[str, Any] | None = None,
|
||||
) -> OutputCheckModelResponse: ...
|
||||
|
||||
|
||||
# Recognized aliases for the file id key in a produced file payload. Mirrors
|
||||
# :data:`output_type_checker._FILE_ID_KEYS` so both stages handle the same
|
||||
# field set.
|
||||
_FILE_ID_KEYS: tuple[str, ...] = ("file_id", "upload_file_id", "tool_file_id")
|
||||
|
||||
# Verdict / reason parsing patterns. The prompt instructs the model to start
|
||||
# with ``VERDICT: PASS|FAIL`` followed by ``REASON: ...``; we tolerate
|
||||
# whitespace and case variations.
|
||||
_VERDICT_PATTERN = re.compile(r"VERDICT\s*:\s*(PASS|FAIL)", re.IGNORECASE)
|
||||
_REASON_PATTERN = re.compile(r"REASON\s*:\s*(.+)", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
_DEFAULT_MAX_CONTENT_CHARS = 32_000
|
||||
_TRUNCATION_NOTICE = "…[content truncated]"
|
||||
|
||||
|
||||
class FileOutputCheckExecutor:
|
||||
"""Run benchmark checks against every file output that opted in."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
content_loader: FileContentLoader,
|
||||
model_invoker: OutputCheckModelInvoker,
|
||||
max_content_chars: int = _DEFAULT_MAX_CONTENT_CHARS,
|
||||
) -> None:
|
||||
self._content_loader = content_loader
|
||||
self._model_invoker = model_invoker
|
||||
self._max_content_chars = max_content_chars
|
||||
|
||||
def check_all(
|
||||
self,
|
||||
*,
|
||||
declared_outputs: list[DeclaredOutputConfig],
|
||||
raw_output: Mapping[str, Any] | Any,
|
||||
tenant_id: str,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
model_settings: Mapping[str, Any] | None = None,
|
||||
) -> FileOutputCheckOutcome:
|
||||
"""Run benchmark checks for the file outputs that opted in.
|
||||
|
||||
``raw_output`` matches the shape of ``run_succeeded.data.output`` —
|
||||
normally a dict, but we widen the signature like
|
||||
:meth:`PerOutputTypeChecker.check` so a misbehaving backend cannot
|
||||
crash the executor. Non-mapping payloads yield an empty outcome
|
||||
(type-check already surfaced the failure).
|
||||
|
||||
Skips:
|
||||
- non-file declared outputs (silently — handled by type check)
|
||||
- file outputs without ``check.enabled``
|
||||
- file outputs whose produced value is missing (already flagged by
|
||||
type check; we do not want to surface a duplicate signal here)
|
||||
"""
|
||||
if not isinstance(raw_output, Mapping):
|
||||
return FileOutputCheckOutcome(results=())
|
||||
|
||||
results: list[FileOutputCheckResult] = []
|
||||
for declared in declared_outputs:
|
||||
if declared.type != DeclaredOutputType.FILE:
|
||||
continue
|
||||
if not (declared.check and declared.check.enabled):
|
||||
continue
|
||||
produced_value = raw_output.get(declared.name)
|
||||
if produced_value is None:
|
||||
results.append(
|
||||
FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason=f"Produced value for {declared.name!r} is missing.",
|
||||
skip_reason=FileOutputCheckSkipReason.PRODUCED_FILE_MISSING,
|
||||
)
|
||||
)
|
||||
continue
|
||||
results.append(
|
||||
self._check_one(
|
||||
declared=declared,
|
||||
produced_value=produced_value,
|
||||
tenant_id=tenant_id,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
)
|
||||
return FileOutputCheckOutcome(results=tuple(results))
|
||||
|
||||
def _check_one(
|
||||
self,
|
||||
*,
|
||||
declared: DeclaredOutputConfig,
|
||||
produced_value: Any,
|
||||
tenant_id: str,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
model_settings: Mapping[str, Any] | None,
|
||||
) -> FileOutputCheckResult:
|
||||
# ``declared.check`` is guaranteed non-None by the caller; access via
|
||||
# the model validator that ensured prompt + benchmark_file_ref exist.
|
||||
assert declared.check is not None
|
||||
assert declared.check.enabled
|
||||
check = declared.check
|
||||
|
||||
produced_file_id = self._extract_file_id(produced_value)
|
||||
if produced_file_id is None:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason="Produced value lacks a recognized file_id field.",
|
||||
skip_reason=FileOutputCheckSkipReason.PRODUCED_FILE_MISSING,
|
||||
)
|
||||
bench_ref = check.benchmark_file_ref or {}
|
||||
benchmark_file_id = self._extract_file_id(bench_ref)
|
||||
if benchmark_file_id is None:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason="benchmark_file_ref is missing a recognized file_id field.",
|
||||
skip_reason=FileOutputCheckSkipReason.BENCHMARK_FILE_NOT_ACCESSIBLE,
|
||||
)
|
||||
|
||||
produced = self._content_loader.load(file_id=produced_file_id, tenant_id=tenant_id)
|
||||
if produced is None:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason=f"Produced file {produced_file_id!r} is not accessible to tenant.",
|
||||
skip_reason=FileOutputCheckSkipReason.PRODUCED_FILE_MISSING,
|
||||
)
|
||||
if produced.unsupported:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason="Produced file type is not supported for output check.",
|
||||
skip_reason=FileOutputCheckSkipReason.UNSUPPORTED_FILE_FOR_OUTPUT_CHECK,
|
||||
)
|
||||
|
||||
benchmark = self._content_loader.load(file_id=benchmark_file_id, tenant_id=tenant_id)
|
||||
if benchmark is None:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason=f"Benchmark file {benchmark_file_id!r} is not accessible to tenant.",
|
||||
skip_reason=FileOutputCheckSkipReason.BENCHMARK_FILE_NOT_ACCESSIBLE,
|
||||
)
|
||||
if benchmark.unsupported:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason="Benchmark file type is not supported for output check.",
|
||||
skip_reason=FileOutputCheckSkipReason.UNSUPPORTED_FILE_FOR_OUTPUT_CHECK,
|
||||
)
|
||||
|
||||
benchmark_text, produced_text, truncated = self._truncate_for_budget(
|
||||
benchmark_text=benchmark.text, produced_text=produced.text
|
||||
)
|
||||
truncated = truncated or benchmark.truncated or produced.truncated
|
||||
|
||||
# ``check.prompt`` is guaranteed non-None by the model validator when
|
||||
# enabled=True, but we coerce defensively for older records.
|
||||
user_prompt = check.prompt or ""
|
||||
prompt = self._build_prompt(
|
||||
user_prompt=user_prompt,
|
||||
benchmark_text=benchmark_text,
|
||||
produced_text=produced_text,
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._model_invoker.invoke(
|
||||
tenant_id=tenant_id,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
model_settings=model_settings,
|
||||
)
|
||||
except OutputCheckModelInvocationError as exc:
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=FileOutputCheckStatus.SKIPPED,
|
||||
reason=f"Model invocation failed: {exc}",
|
||||
skip_reason=FileOutputCheckSkipReason.MODEL_INVOCATION_ERROR,
|
||||
content_truncated=truncated,
|
||||
)
|
||||
|
||||
verdict, reason = self._parse_verdict(response.text)
|
||||
if verdict == "pass":
|
||||
status = FileOutputCheckStatus.PASSED
|
||||
elif verdict == "fail":
|
||||
status = FileOutputCheckStatus.FAILED
|
||||
else:
|
||||
# Indeterminate output. We treat it as FAIL so the orchestrator
|
||||
# gets a real signal; the raw model text is included in the
|
||||
# reason for debugging.
|
||||
status = FileOutputCheckStatus.FAILED
|
||||
reason = f"Indeterminate model response: {response.text.strip()[:300]}"
|
||||
|
||||
return FileOutputCheckResult(
|
||||
output_name=declared.name,
|
||||
status=status,
|
||||
reason=reason,
|
||||
usage=response.usage,
|
||||
content_truncated=truncated,
|
||||
)
|
||||
|
||||
def _truncate_for_budget(
|
||||
self,
|
||||
*,
|
||||
benchmark_text: str,
|
||||
produced_text: str,
|
||||
) -> tuple[str, str, bool]:
|
||||
"""Split the char budget equally between benchmark and produced.
|
||||
|
||||
Returns ``(benchmark_text, produced_text, truncated)``. Each half is
|
||||
capped at ``max_content_chars // 2`` so a single huge document cannot
|
||||
starve the other side.
|
||||
"""
|
||||
half = self._max_content_chars // 2
|
||||
truncated = False
|
||||
if len(benchmark_text) > half:
|
||||
benchmark_text = benchmark_text[:half] + _TRUNCATION_NOTICE
|
||||
truncated = True
|
||||
if len(produced_text) > half:
|
||||
produced_text = produced_text[:half] + _TRUNCATION_NOTICE
|
||||
truncated = True
|
||||
return benchmark_text, produced_text, truncated
|
||||
|
||||
@staticmethod
|
||||
def _build_prompt(*, user_prompt: str, benchmark_text: str, produced_text: str) -> str:
|
||||
return (
|
||||
"You are an output validator. The user has defined the following acceptance criteria:\n"
|
||||
"<criteria>\n"
|
||||
f"{user_prompt.strip()}\n"
|
||||
"</criteria>\n\n"
|
||||
"Below is the BENCHMARK file the produced output should be evaluated against:\n"
|
||||
"<benchmark>\n"
|
||||
f"{benchmark_text}\n"
|
||||
"</benchmark>\n\n"
|
||||
"Below is the PRODUCED file from the agent run:\n"
|
||||
"<produced>\n"
|
||||
f"{produced_text}\n"
|
||||
"</produced>\n\n"
|
||||
"Decide whether the PRODUCED file satisfies the criteria when compared to the BENCHMARK.\n"
|
||||
"Respond strictly in this format on two lines:\n"
|
||||
"VERDICT: PASS\n"
|
||||
"REASON: <one-sentence explanation>\n"
|
||||
"(or VERDICT: FAIL when the produced file does not meet the criteria)."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_verdict(text: str) -> tuple[Literal["pass", "fail", "unknown"], str]:
|
||||
verdict_match = _VERDICT_PATTERN.search(text)
|
||||
reason_match = _REASON_PATTERN.search(text)
|
||||
if verdict_match is None:
|
||||
return "unknown", text.strip()[:300]
|
||||
verdict_raw = verdict_match.group(1).lower()
|
||||
verdict: Literal["pass", "fail", "unknown"] = "pass" if verdict_raw == "pass" else "fail"
|
||||
if reason_match is not None:
|
||||
reason = reason_match.group(1).strip()
|
||||
# Reasons can run multi-line in some model outputs; first line is
|
||||
# usually the salient one.
|
||||
reason = reason.split("\n", 1)[0].strip()
|
||||
else:
|
||||
reason = "" if verdict == "pass" else "Model returned FAIL without a reason."
|
||||
return verdict, reason
|
||||
|
||||
@staticmethod
|
||||
def _extract_file_id(value: Any) -> str | None:
|
||||
if not isinstance(value, Mapping):
|
||||
return None
|
||||
for key in _FILE_ID_KEYS:
|
||||
candidate = value.get(key)
|
||||
if isinstance(candidate, str) and candidate:
|
||||
return candidate
|
||||
return None
|
||||
@ -1,133 +0,0 @@
|
||||
"""Production :class:`OutputCheckModelInvoker` backed by ``ModelManager``.
|
||||
|
||||
Stage 4 §6: the file-output check needs a direct, non-streaming LLM call that
|
||||
yields one assistant message plus token usage. Implementation choices:
|
||||
|
||||
* **No agent backend hop.** This is a one-shot evaluation, not an agentic
|
||||
loop. Going through the agent backend would conflate output-check usage
|
||||
with agent-run usage and introduce unnecessary protocol surface.
|
||||
* **Reuse Agent Soul's model identity.** Callers supply ``provider`` and
|
||||
``model_name`` from the same :class:`AgentSoulModelConfig` the agent itself
|
||||
uses; the check therefore inherits the tenant's existing model credentials
|
||||
and configuration without a separate setup.
|
||||
* **Bucket usage separately.** Returned ``FileOutputCheckUsage`` is later
|
||||
recorded under ``WorkflowNodeExecutionMetadata.output_check_usage`` per
|
||||
decision D-2; never merged with agent-run usage.
|
||||
|
||||
Any exception raised inside ``ModelInstance.invoke_llm`` (provider error,
|
||||
credential issue, network timeout, ...) is converted to a single
|
||||
:class:`OutputCheckModelInvocationError`. The executor catches that and emits
|
||||
a SKIPPED result tagged ``output_check_model_error`` so the surrounding retry
|
||||
/ fail-branch logic still proceeds deterministically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
from .output_check_executor import (
|
||||
FileOutputCheckUsage,
|
||||
OutputCheckModelInvocationError,
|
||||
OutputCheckModelResponse,
|
||||
)
|
||||
|
||||
# Resolves a tenant id to a fresh ``ModelManager``. Defined as a Callable
|
||||
# alias rather than a Protocol class so plain functions injected by tests
|
||||
# (e.g. ``lambda _: stub``) satisfy the type without subclassing.
|
||||
ModelManagerFactory = Callable[[str], ModelManager]
|
||||
|
||||
|
||||
class ModelRuntimeOutputCheckInvoker:
|
||||
"""Direct LLM invocation via the existing model_runtime stack.
|
||||
|
||||
A fresh :class:`ModelManager` is built per invocation by default so
|
||||
credential-cache staleness cannot leak across workflow runs. Tests can
|
||||
inject their own ``model_manager_factory`` to avoid touching the provider
|
||||
manager and DB.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_manager_factory: ModelManagerFactory | None = None,
|
||||
) -> None:
|
||||
self._factory: ModelManagerFactory = model_manager_factory or _default_model_manager_factory
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
model_provider: str,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
model_settings: Mapping[str, Any] | None = None,
|
||||
) -> OutputCheckModelResponse:
|
||||
try:
|
||||
manager = self._factory(tenant_id)
|
||||
model_instance = manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=model_provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
result = model_instance.invoke_llm(
|
||||
prompt_messages=[UserPromptMessage(content=prompt)],
|
||||
model_parameters=dict(model_settings or {}),
|
||||
stream=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise OutputCheckModelInvocationError(str(exc)) from exc
|
||||
|
||||
if not isinstance(result, LLMResult):
|
||||
# ``stream=False`` is documented to return LLMResult; if the
|
||||
# provider implementation breaks that contract surface it through
|
||||
# the same uniform error path rather than a cryptic AttributeError
|
||||
# later.
|
||||
raise OutputCheckModelInvocationError(
|
||||
f"Expected LLMResult from non-streaming invoke, got {type(result).__name__}"
|
||||
)
|
||||
|
||||
text = _flatten_assistant_text(result)
|
||||
usage = _to_file_output_check_usage(result.usage)
|
||||
return OutputCheckModelResponse(text=text, usage=usage)
|
||||
|
||||
|
||||
def _default_model_manager_factory(tenant_id: str) -> ModelManager:
|
||||
return ModelManager.for_tenant(tenant_id)
|
||||
|
||||
|
||||
def _flatten_assistant_text(result: LLMResult) -> str:
|
||||
"""Extract a plain string from ``AssistantPromptMessage.content``.
|
||||
|
||||
The model runtime allows multimodal content lists; for output check we
|
||||
only ever expect a text response, but defensive flattening prevents an
|
||||
unexpected list payload from crashing the parser.
|
||||
"""
|
||||
content = result.message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for piece in content:
|
||||
piece_text = getattr(piece, "data", None) or getattr(piece, "text", None)
|
||||
if isinstance(piece_text, str):
|
||||
parts.append(piece_text)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _to_file_output_check_usage(usage: LLMUsage) -> FileOutputCheckUsage:
|
||||
"""Project an :class:`LLMUsage` into the executor's narrower shape."""
|
||||
return FileOutputCheckUsage(
|
||||
prompt_tokens=usage.prompt_tokens,
|
||||
completion_tokens=usage.completion_tokens,
|
||||
total_tokens=usage.total_tokens,
|
||||
total_price=usage.total_price,
|
||||
currency=usage.currency,
|
||||
latency_ms=int(usage.latency * 1000),
|
||||
)
|
||||
@ -4,7 +4,8 @@ from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, cast
|
||||
|
||||
from dify_agent.protocol import CreateRunRequest, ExecutionContext
|
||||
from dify_agent.layers.execution_context import DifyExecutionContextLayerConfig
|
||||
from dify_agent.protocol import CreateRunRequest
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendModelConfig,
|
||||
@ -105,16 +106,20 @@ class WorkflowAgentRuntimeRequestBuilder:
|
||||
request = self._request_builder.build_for_workflow_node(
|
||||
AgentBackendWorkflowNodeRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
model=agent_soul.model.model,
|
||||
user_id=context.dify_context.user_id,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
|
||||
),
|
||||
execution_context=ExecutionContext(
|
||||
# The execution-context layer is now the only public protocol
|
||||
# carrier for Dify tenant/user/run identifiers. ``user_id`` must
|
||||
# be forwarded here because downstream plugin-daemon provider and
|
||||
# tool clients read it from this layer rather than from any
|
||||
# parallel top-level request field.
|
||||
execution_context=DifyExecutionContextLayerConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
user_id=context.dify_context.user_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
workflow_id=context.workflow_id,
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
|
||||
@ -1,149 +0,0 @@
|
||||
"""Production :class:`FileContentLoader` backed by ``upload_files`` + storage.
|
||||
|
||||
Stage 4 §6: the :class:`FileOutputCheckExecutor` needs to read both the
|
||||
benchmark file (operator-supplied) and the agent-produced file as text so the
|
||||
LLM evaluator can compare them. Both are stored in Dify's ``upload_files``
|
||||
table; this adapter:
|
||||
|
||||
1. resolves the ``file_id`` to an ``UploadFile`` row inside the caller's tenant
|
||||
(cross-tenant access returns ``None`` — never raises);
|
||||
2. classifies the file's extension as text-extractable or unsupported (image /
|
||||
archive / audio / video / executable are all treated as unsupported until
|
||||
the executor learns to feed vision input or downloads);
|
||||
3. delegates text extraction to :class:`ExtractProcessor.load_from_upload_file`
|
||||
so PDFs / Word / CSV / Markdown / HTML / Text all reuse the existing RAG
|
||||
pipeline rather than re-implementing decoders.
|
||||
|
||||
Any extraction failure (corrupt file, unsupported encoding, ETL backend down)
|
||||
becomes a ``LoadedFileContent`` with ``unsupported=True`` instead of an
|
||||
exception so the executor can map it to a deterministic ``SKIPPED`` result.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import DataError, SQLAlchemyError
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from models.model import UploadFile
|
||||
|
||||
from .output_check_executor import LoadedFileContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# File extensions explicitly rejected because text extraction either does not
|
||||
# apply (images, audio, video) or yields no useful comparison material
|
||||
# (archives, executables). Anything outside this set falls through to
|
||||
# ``ExtractProcessor`` which has its own fallback to a text decoder.
|
||||
_UNSUPPORTED_EXTENSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
# Images — handled later by a vision-capable code path; deferred to
|
||||
# stage 4.1 per design doc §6.3.
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
".ico",
|
||||
# Audio / video.
|
||||
".mp3",
|
||||
".wav",
|
||||
".ogg",
|
||||
".flac",
|
||||
".mp4",
|
||||
".m4a",
|
||||
".mov",
|
||||
".webm",
|
||||
".avi",
|
||||
".mkv",
|
||||
# Archives.
|
||||
".zip",
|
||||
".tar",
|
||||
".gz",
|
||||
".bz2",
|
||||
".7z",
|
||||
".rar",
|
||||
# Executables / native binaries.
|
||||
".exe",
|
||||
".dll",
|
||||
".so",
|
||||
".dylib",
|
||||
".bin",
|
||||
".dmg",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class UploadFileContentLoader:
|
||||
"""Resolve an ``upload_files`` row → extracted plain text content.
|
||||
|
||||
Returns ``None`` (not an exception) for unknown / cross-tenant / DB-error
|
||||
cases so the executor can produce a deterministic SKIPPED result. Returns
|
||||
a ``LoadedFileContent`` with ``unsupported=True`` when the file format is
|
||||
recognized but not text-extractable (e.g. image, archive). Returns a
|
||||
populated ``LoadedFileContent`` on success.
|
||||
"""
|
||||
|
||||
def load(self, *, file_id: str, tenant_id: str) -> LoadedFileContent | None:
|
||||
if not file_id or not tenant_id:
|
||||
return None
|
||||
try:
|
||||
UUID(file_id)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
try:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id))
|
||||
except (DataError, SQLAlchemyError):
|
||||
logger.warning("UploadFileContentLoader: DB error while resolving file_id=%s", file_id, exc_info=True)
|
||||
return None
|
||||
|
||||
if upload_file is None or upload_file.tenant_id != tenant_id:
|
||||
return None
|
||||
|
||||
extension = self._extension_of(upload_file)
|
||||
if extension in _UNSUPPORTED_EXTENSIONS:
|
||||
return LoadedFileContent(text="", unsupported=True)
|
||||
|
||||
try:
|
||||
extracted = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
|
||||
except Exception:
|
||||
# Any failure inside the extraction pipeline (corrupt file,
|
||||
# missing storage object, ETL backend down, plugin error...) is
|
||||
# surfaced as "unsupported" so the executor produces a SKIPPED
|
||||
# result rather than failing the whole node.
|
||||
logger.warning(
|
||||
"UploadFileContentLoader: extraction failed for file_id=%s ext=%s",
|
||||
file_id,
|
||||
extension,
|
||||
exc_info=True,
|
||||
)
|
||||
return LoadedFileContent(text="", unsupported=True)
|
||||
|
||||
if not isinstance(extracted, str):
|
||||
# ExtractProcessor.load_from_upload_file returns ``list[Document]``
|
||||
# only when ``return_text=False``; defensive guard.
|
||||
return LoadedFileContent(text="", unsupported=True)
|
||||
|
||||
return LoadedFileContent(text=extracted)
|
||||
|
||||
@staticmethod
|
||||
def _extension_of(upload_file: UploadFile) -> str:
|
||||
"""Lowercased ``.ext`` form of :attr:`UploadFile.extension`.
|
||||
|
||||
``UploadFile.extension`` is stored without the leading dot (e.g.
|
||||
``"pdf"``), but the unsupported-set is keyed by ``".pdf"`` form so
|
||||
we normalize before lookup. Empty extension stays empty.
|
||||
"""
|
||||
raw = (upload_file.extension or "").strip().lower()
|
||||
if not raw:
|
||||
return ""
|
||||
return raw if raw.startswith(".") else f".{raw}"
|
||||
@ -45,6 +45,7 @@ SPEC_TARGETS: tuple[SpecTarget, ...] = (
|
||||
SpecTarget(route="/console/api/swagger.json", filename="console-swagger.json", namespace="console"),
|
||||
SpecTarget(route="/api/swagger.json", filename="web-swagger.json", namespace="web"),
|
||||
SpecTarget(route="/v1/swagger.json", filename="service-swagger.json", namespace="service"),
|
||||
SpecTarget(route="/openapi/v1/swagger.json", filename="openapi-swagger.json", namespace="openapi"),
|
||||
)
|
||||
|
||||
|
||||
@ -161,6 +162,8 @@ def create_spec_app() -> Flask:
|
||||
|
||||
from controllers.console import bp as console_bp
|
||||
from controllers.console import console_ns
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.web import bp as web_bp
|
||||
@ -169,8 +172,9 @@ def create_spec_app() -> Flask:
|
||||
app.register_blueprint(console_bp)
|
||||
app.register_blueprint(web_bp)
|
||||
app.register_blueprint(service_api_bp)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
for namespace in (console_ns, web_ns, service_api_ns):
|
||||
for namespace in (console_ns, web_ns, service_api_ns, openapi_ns):
|
||||
for api in namespace.apis:
|
||||
_materialize_inline_model_definitions(api)
|
||||
|
||||
@ -201,6 +205,13 @@ def _registered_models(namespace: str) -> dict[str, object]:
|
||||
for api in service_api_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
if namespace == "openapi":
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
models = dict(openapi_ns.models)
|
||||
for api in openapi_ns.apis:
|
||||
models.update(api.models)
|
||||
return models
|
||||
|
||||
raise ValueError(f"unknown Swagger namespace: {namespace}")
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
|
||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
|
||||
OPENAPI_MAX_AGE_SECONDS: int = 600
|
||||
|
||||
|
||||
def _apply_cors_once(bp, /, **cors_kwargs):
|
||||
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.mcp import bp as mcp_bp
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.trigger import bp as trigger_bp
|
||||
from controllers.web import bp as web_bp
|
||||
@ -41,6 +44,23 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
if dify_config.OPENAPI_ENABLED:
|
||||
# User-scoped programmatic API. Default empty allowlist = same-origin
|
||||
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
|
||||
# integrations. supports_credentials so cookie-authed approve/deny
|
||||
# work; cross-origin OPTIONS without an allowed origin will fail
|
||||
# the same as on the console blueprint.
|
||||
_apply_cors_once(
|
||||
openapi_bp,
|
||||
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
_apply_cors_once(
|
||||
web_bp,
|
||||
resources={
|
||||
|
||||
@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token, extract_webapp_passport
|
||||
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
# approval-context) sit under @login_required and authenticate via
|
||||
# the console session cookie. Cookie-only on purpose — bearer
|
||||
# tokens (dfoa_/dfoe_) live on the Authorization header and are
|
||||
# validated by AppPipeline, not flask-login.
|
||||
cookie_token = extract_console_cookie_token(request)
|
||||
if not cookie_token:
|
||||
return None
|
||||
try:
|
||||
decoded = PassportService().verify(cookie_token)
|
||||
except Exception:
|
||||
return None
|
||||
user_id = decoded.get("user_id")
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
23
api/extensions/ext_oauth_bearer.py
Normal file
23
api/extensions/ext_oauth_bearer.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import build_and_bind
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return dify_config.ENABLE_OAUTH_BEARER
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
# scoped_session isn't a context manager; request teardown closes it.
|
||||
def session_factory():
|
||||
return db.session
|
||||
|
||||
build_and_bind(session_factory=session_factory, redis_client=redis_client)
|
||||
205
api/libs/device_flow_security.py
Normal file
205
api/libs/device_flow_security.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
|
||||
from flask import Blueprint
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
from libs.token import is_secure
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# enterprise_only decorator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
|
||||
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
|
||||
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
|
||||
|
||||
|
||||
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||
responses don't consume the bucket.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status not in _EE_ENABLED_STATUSES:
|
||||
raise NotFound()
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# approval_grant cookie
|
||||
# ============================================================================
|
||||
|
||||
|
||||
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
|
||||
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
|
||||
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300 # 5 min
|
||||
NONCE_TTL_SECONDS = 600 # 2x cookie TTL — defeats clock-skew late replay
|
||||
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
|
||||
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ApprovalGrantClaims:
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
user_code: str
|
||||
nonce: str
|
||||
csrf_token: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
def mint_approval_grant(
|
||||
*,
|
||||
keyset: jws.KeySet,
|
||||
iss: str,
|
||||
subject_email: str,
|
||||
subject_issuer: str,
|
||||
user_code: str,
|
||||
) -> tuple[str, ApprovalGrantClaims]:
|
||||
"""Use ``approval_grant_cookie_kwargs`` to set the cookie — single
|
||||
source of truth for Path/HttpOnly/Secure/SameSite.
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
exp = now + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
nonce = _random_opaque()
|
||||
csrf_token = _random_opaque()
|
||||
|
||||
payload = {
|
||||
"iss": iss,
|
||||
"subject_email": subject_email,
|
||||
"subject_issuer": subject_issuer,
|
||||
"user_code": user_code,
|
||||
"nonce": nonce,
|
||||
"csrf_token": csrf_token,
|
||||
}
|
||||
token = jws.sign(keyset, payload, aud=jws.AUD_APPROVAL_GRANT, ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS)
|
||||
|
||||
return token, ApprovalGrantClaims(
|
||||
subject_email=subject_email,
|
||||
subject_issuer=subject_issuer,
|
||||
user_code=user_code,
|
||||
nonce=nonce,
|
||||
csrf_token=csrf_token,
|
||||
expires_at=exp,
|
||||
)
|
||||
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
|
||||
"""Sig + aud + exp only — nonce consumption is the caller's job."""
|
||||
# lazy import: breaks libs → controllers cycle
|
||||
from controllers.openapi._models import ApprovalGrantClaimsPayload
|
||||
|
||||
raw = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||
try:
|
||||
parsed = ApprovalGrantClaimsPayload.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise jws.VerifyError(f"claim shape invalid: {e}") from e
|
||||
|
||||
return ApprovalGrantClaims(
|
||||
subject_email=parsed.subject_email,
|
||||
subject_issuer=parsed.subject_issuer,
|
||||
user_code=parsed.user_code,
|
||||
nonce=parsed.nonce,
|
||||
csrf_token=parsed.csrf_token,
|
||||
expires_at=datetime.fromtimestamp(raw["exp"], tz=UTC),
|
||||
)
|
||||
|
||||
|
||||
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
|
||||
if not nonce:
|
||||
return False
|
||||
return bool(
|
||||
redis_client.set(
|
||||
SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce),
|
||||
"1",
|
||||
nx=True,
|
||||
ex=NONCE_TTL_SECONDS,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def approval_grant_cookie_kwargs(value: str) -> dict:
|
||||
"""``secure`` follows is_secure() so HTTP-only deployments don't
|
||||
silently drop the cookie.
|
||||
"""
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": value,
|
||||
"max_age": APPROVAL_GRANT_COOKIE_TTL_SECONDS,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def approval_grant_cleared_cookie_kwargs() -> dict:
|
||||
return {
|
||||
"key": APPROVAL_GRANT_COOKIE_NAME,
|
||||
"value": "",
|
||||
"max_age": 0,
|
||||
"path": APPROVAL_GRANT_COOKIE_PATH,
|
||||
"secure": is_secure(),
|
||||
"httponly": True,
|
||||
"samesite": "Lax",
|
||||
}
|
||||
|
||||
|
||||
def _random_opaque() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Anti-framing headers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
_ANTI_FRAMING_HEADERS = {
|
||||
"X-Frame-Options": "DENY",
|
||||
"Content-Security-Policy": "frame-ancestors 'none'",
|
||||
}
|
||||
|
||||
|
||||
def attach_anti_framing(bp: Blueprint) -> None:
|
||||
"""X-Frame-Options + CSP on every response from ``bp`` (CI invariant #4)."""
|
||||
|
||||
@bp.after_request
|
||||
def _apply_headers(response): # pyright: ignore[reportUnusedFunction]
|
||||
for name, value in _ANTI_FRAMING_HEADERS.items():
|
||||
response.headers.setdefault(name, value)
|
||||
return response
|
||||
@ -76,6 +76,7 @@ def register_external_error_handlers(api: Api):
|
||||
|
||||
def handle_value_error(e: ValueError):
|
||||
got_request_exception.send(current_app, exception=e)
|
||||
current_app.logger.exception("value_error in request handler")
|
||||
status_code = 400
|
||||
data = {"code": "invalid_param", "message": str(e), "status": status_code}
|
||||
return data, status_code
|
||||
|
||||
@ -595,3 +595,18 @@ class RateLimiter:
|
||||
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
def seconds_until_available(self, email: str) -> int:
|
||||
"""Seconds until the oldest in-window entry expires, freeing a slot.
|
||||
|
||||
Defensive floor of 1 second. Caller should only invoke this after
|
||||
is_rate_limited() returned True.
|
||||
"""
|
||||
key = self._get_key(email)
|
||||
oldest = cast(Any, self._redis_client).zrange(key, 0, 0, withscores=True)
|
||||
if not oldest:
|
||||
return 1
|
||||
_member, score = oldest[0]
|
||||
free_at = int(score) + self.time_window
|
||||
remaining = free_at - int(time.time())
|
||||
return max(remaining, 1)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user