mirror of
https://github.com/langgenius/dify.git
synced 2026-05-29 05:07:55 +08:00
Compare commits
2 Commits
fix/device
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
| 2bff76eee9 | |||
| 50d7fceb1b |
@ -1,15 +0,0 @@
|
||||
**/node_modules
|
||||
**/.pnpm-store
|
||||
**/dist
|
||||
**/.next
|
||||
**/.turbo
|
||||
**/.cache
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
**/.mypy_cache
|
||||
**/.ruff_cache
|
||||
.git
|
||||
.github
|
||||
*.md
|
||||
!web/README.md
|
||||
!api/README.md
|
||||
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -5,7 +5,3 @@
|
||||
# them.
|
||||
|
||||
*.sh text eol=lf
|
||||
|
||||
# Codegen output must stay byte-identical across platforms so
|
||||
# `pnpm tree:check` in CI does not trip on CRLF rewrites.
|
||||
*.generated.ts text eol=lf
|
||||
|
||||
5
.github/CODEOWNERS
vendored
5
.github/CODEOWNERS
vendored
@ -18,10 +18,6 @@
|
||||
# Docs
|
||||
/docs/ @crazywoola
|
||||
|
||||
# CLI
|
||||
/cli/ @langgenius/maintainers
|
||||
/.github/workflows/cli-tests.yml @langgenius/maintainers
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
/api/ @QuantumGhost
|
||||
|
||||
@ -166,7 +162,6 @@
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
/web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
/web/app/components/develop/template/*.mdx @JzoNgKVO @iamjoel @RiskeyL
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
/web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
|
||||
111
.github/dependabot.yml
vendored
111
.github/dependabot.yml
vendored
@ -110,114 +110,3 @@ updates:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "uv"
|
||||
directory: "/api"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 10
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
flask:
|
||||
patterns:
|
||||
- "flask"
|
||||
- "flask-*"
|
||||
- "werkzeug"
|
||||
- "gunicorn"
|
||||
google:
|
||||
patterns:
|
||||
- "google-*"
|
||||
- "googleapis-*"
|
||||
opentelemetry:
|
||||
patterns:
|
||||
- "opentelemetry-*"
|
||||
pydantic:
|
||||
patterns:
|
||||
- "pydantic"
|
||||
- "pydantic-*"
|
||||
llm:
|
||||
patterns:
|
||||
- "langfuse"
|
||||
- "langsmith"
|
||||
- "litellm"
|
||||
- "mlflow*"
|
||||
- "opik"
|
||||
- "weave*"
|
||||
- "arize*"
|
||||
- "tiktoken"
|
||||
- "transformers"
|
||||
database:
|
||||
patterns:
|
||||
- "sqlalchemy"
|
||||
- "psycopg2*"
|
||||
- "psycogreen"
|
||||
- "redis*"
|
||||
- "alembic*"
|
||||
storage:
|
||||
patterns:
|
||||
- "boto3*"
|
||||
- "botocore*"
|
||||
- "azure-*"
|
||||
- "bce-*"
|
||||
- "cos-python-*"
|
||||
- "esdk-obs-*"
|
||||
- "google-cloud-storage"
|
||||
- "opendal"
|
||||
- "oss2"
|
||||
- "supabase*"
|
||||
- "tos*"
|
||||
vdb:
|
||||
patterns:
|
||||
- "alibabacloud*"
|
||||
- "chromadb"
|
||||
- "clickhouse-*"
|
||||
- "clickzetta-*"
|
||||
- "couchbase"
|
||||
- "elasticsearch"
|
||||
- "opensearch-py"
|
||||
- "oracledb"
|
||||
- "pgvect*"
|
||||
- "pymilvus"
|
||||
- "pymochow"
|
||||
- "pyobvector"
|
||||
- "qdrant-client"
|
||||
- "intersystems-*"
|
||||
- "tablestore"
|
||||
- "tcvectordb"
|
||||
- "tidb-vector"
|
||||
- "upstash-*"
|
||||
- "volcengine-*"
|
||||
- "weaviate-*"
|
||||
- "xinference-*"
|
||||
- "mo-vector"
|
||||
- "mysql-connector-*"
|
||||
dev:
|
||||
patterns:
|
||||
- "coverage"
|
||||
- "dotenv-linter"
|
||||
- "faker"
|
||||
- "lxml-stubs"
|
||||
- "basedpyright"
|
||||
- "ruff"
|
||||
- "pytest*"
|
||||
- "types-*"
|
||||
- "boto3-stubs"
|
||||
- "hypothesis"
|
||||
- "pandas-stubs"
|
||||
- "scipy-stubs"
|
||||
- "import-linter"
|
||||
- "celery-types"
|
||||
- "mypy*"
|
||||
- "pyrefly"
|
||||
python-packages:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
target-branch: "lts/1.13.x"
|
||||
open-pull-requests-limit: 5
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
groups:
|
||||
github-actions-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
88
.github/workflows/cli-release.yml
vendored
88
.github/workflows/cli-release.yml
vendored
@ -1,88 +0,0 @@
|
||||
name: CLI Release
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- 'difyctl-v*'
|
||||
|
||||
concurrency:
|
||||
group: cli-release-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
release:
|
||||
name: build standalone binaries (all targets)
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: write
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@4bc047ad259df6fc24a6c9b0f9a0cb08cf17fbe5 # v2.0.2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Read cli/package.json
|
||||
id: manifest
|
||||
run: |
|
||||
version=$(node -p "require('./package.json').version")
|
||||
channel=$(node -p "require('./package.json').difyctl.channel")
|
||||
minDify=$(node -p "require('./package.json').difyctl.compat.minDify")
|
||||
maxDify=$(node -p "require('./package.json').difyctl.compat.maxDify")
|
||||
{
|
||||
echo "version=$version"
|
||||
echo "channel=$channel"
|
||||
echo "minDify=$minDify"
|
||||
echo "maxDify=$maxDify"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Validate manifest
|
||||
run: scripts/release-validate-manifest.sh
|
||||
|
||||
- name: Install cross-arch native prebuilds
|
||||
# Re-installs node_modules with every @napi-rs/keyring platform variant
|
||||
# so `bun build --compile` can embed the right .node into each target.
|
||||
working-directory: ./
|
||||
run: NPM_CONFIG_USERCONFIG="$PWD/cli/scripts/cross-arch.npmrc" pnpm install --frozen-lockfile
|
||||
|
||||
- name: Compile standalone binaries (all targets)
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
DIFYCTL_CHANNEL: ${{ steps.manifest.outputs.channel }}
|
||||
DIFYCTL_MIN_DIFY: ${{ steps.manifest.outputs.minDify }}
|
||||
DIFYCTL_MAX_DIFY: ${{ steps.manifest.outputs.maxDify }}
|
||||
run: |
|
||||
DIFYCTL_COMMIT="$(git rev-parse HEAD)" \
|
||||
DIFYCTL_BUILD_DATE="$(git log -1 --format=%cI HEAD)" \
|
||||
pnpm build:bin
|
||||
|
||||
- name: Generate sha256 checksum file
|
||||
env:
|
||||
CLI_VERSION: ${{ steps.manifest.outputs.version }}
|
||||
run: scripts/release-write-checksums.sh
|
||||
|
||||
- name: Publish GitHub Release
|
||||
uses: softprops/action-gh-release@72f2c25fcb47643c292f7107632f7a47c1df5cd8 # v2.3.2
|
||||
with:
|
||||
tag_name: difyctl-v${{ steps.manifest.outputs.version }}
|
||||
name: difyctl ${{ steps.manifest.outputs.version }}
|
||||
prerelease: ${{ steps.manifest.outputs.channel != 'stable' }}
|
||||
generate_release_notes: true
|
||||
fail_on_unmatched_files: true
|
||||
files: |
|
||||
cli/dist/bin/difyctl-v*
|
||||
60
.github/workflows/cli-smoke.yml
vendored
60
.github/workflows/cli-smoke.yml
vendored
@ -1,60 +0,0 @@
|
||||
name: CLI Smoke (live dify)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dify_version:
|
||||
description: "Dify image tag to test against (e.g. 1.7.0)"
|
||||
type: string
|
||||
required: true
|
||||
cli_ref:
|
||||
description: "Git ref to build the cli from (default: current branch)"
|
||||
type: string
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
smoke:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout cli ref
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
ref: ${{ inputs.cli_ref || github.ref }}
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Bring up dify
|
||||
env:
|
||||
DIFY_VERSION: ${{ inputs.dify_version }}
|
||||
run: |
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
DIFY_API_IMAGE_TAG="$DIFY_VERSION" \
|
||||
DIFY_WEB_IMAGE_TAG="$DIFY_VERSION" \
|
||||
docker compose up -d api worker web db redis
|
||||
for i in $(seq 1 60); do
|
||||
if curl -fsS http://localhost:5001/health >/dev/null 2>&1; then
|
||||
echo "dify api ready after ${i}s"
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run smoke against live dify
|
||||
working-directory: ./cli
|
||||
run: pnpm exec tsx scripts/run-smoke.ts --base-url http://localhost:5001
|
||||
|
||||
- name: Dump dify logs on failure
|
||||
if: failure()
|
||||
run: |
|
||||
cd docker
|
||||
docker compose logs api worker web --tail=200
|
||||
50
.github/workflows/cli-tests.yml
vendored
50
.github/workflows/cli-tests.yml
vendored
@ -1,50 +0,0 @@
|
||||
name: CLI Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
secrets:
|
||||
CODECOV_TOKEN:
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: cli-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: CLI Tests (${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [depot-ubuntu-24.04, windows-latest, macos-latest]
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./cli
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: CI pipeline (typecheck, lint, coverage, build)
|
||||
run: pnpm ci
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' && matrix.os == 'depot-ubuntu-24.04' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: cli/coverage
|
||||
flags: cli
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
73
.github/workflows/main-ci.yml
vendored
73
.github/workflows/main-ci.yml
vendored
@ -42,7 +42,6 @@ jobs:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
cli-changed: ${{ steps.changes.outputs.cli }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
@ -63,18 +62,6 @@ jobs:
|
||||
- 'docker/generate_docker_compose'
|
||||
- 'docker/ssrf_proxy/**'
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
cli:
|
||||
- 'cli/**'
|
||||
- 'packages/tsconfig/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- 'eslint.config.mjs'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/cli-tests.yml'
|
||||
- '.github/workflows/cli-docker-build.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
@ -197,66 +184,6 @@ jobs:
|
||||
echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
cli-tests-run:
|
||||
name: Run CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed == 'true'
|
||||
uses: ./.github/workflows/cli-tests.yml
|
||||
secrets: inherit
|
||||
|
||||
cli-tests-skip:
|
||||
name: Skip CLI Tests
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.cli-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped CLI tests
|
||||
run: echo "No CLI-related changes detected; skipping CLI tests."
|
||||
|
||||
cli-tests:
|
||||
name: CLI Tests
|
||||
if: ${{ always() }}
|
||||
needs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
- cli-tests-run
|
||||
- cli-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize CLI Tests status
|
||||
env:
|
||||
SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }}
|
||||
TESTS_CHANGED: ${{ needs.check-changes.outputs.cli-changed }}
|
||||
RUN_RESULT: ${{ needs.cli-tests-run.result }}
|
||||
SKIP_RESULT: ${{ needs.cli-tests-skip.result }}
|
||||
run: |
|
||||
if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then
|
||||
echo "CLI tests were skipped because this workflow run duplicated a successful or newer run."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "$TESTS_CHANGED" == 'true' ]]; then
|
||||
if [[ "$RUN_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests ran successfully."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were required but finished with result: $RUN_RESULT" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$SKIP_RESULT" == 'success' ]]; then
|
||||
echo "CLI tests were skipped because no CLI-related files changed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "CLI tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2
|
||||
exit 1
|
||||
|
||||
web-tests-run:
|
||||
name: Run Web Tests
|
||||
needs:
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@ -115,12 +115,6 @@ venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# cli/ has a src/env/ module (DIFY_* registry) — don't treat it as a venv
|
||||
!/cli/src/env/
|
||||
!/cli/src/commands/env/
|
||||
# cli/scripts/lib/ holds TS build helpers (resolve-buildinfo etc.) — don't treat as Python lib/
|
||||
!/cli/scripts/lib/
|
||||
.conda/
|
||||
|
||||
# Spyder project settings
|
||||
@ -253,9 +247,8 @@ scripts/stress-test/reports/
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
*.local.toml
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.context/
|
||||
.context/*
|
||||
.eslintcache
|
||||
|
||||
@ -27,7 +27,7 @@ COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev --no-editable
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# production stage
|
||||
FROM base AS production
|
||||
|
||||
@ -159,7 +159,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -204,7 +203,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
@ -223,11 +221,10 @@ def initialize_extensions(app: DifyApp):
|
||||
|
||||
def create_migrations_app() -> DifyApp:
|
||||
app = create_flask_app_with_configs()
|
||||
from extensions import ext_commands, ext_database, ext_migrate
|
||||
from extensions import ext_database, ext_migrate
|
||||
|
||||
# Initialize only required extensions
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init_app(app)
|
||||
ext_commands.init_app(app)
|
||||
|
||||
return app
|
||||
|
||||
@ -30,23 +30,19 @@ from clients.agent_backend.factory import create_agent_backend_run_client
|
||||
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
|
||||
from clients.agent_backend.request_builder import (
|
||||
AGENT_SOUL_PROMPT_LAYER_ID,
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID,
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
CleanupLayerSpec,
|
||||
extract_cleanup_layer_specs,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_SOUL_PROMPT_LAYER_ID",
|
||||
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
|
||||
"DIFY_PLUGIN_TOOLS_LAYER_ID",
|
||||
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
|
||||
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
|
||||
"WORKFLOW_USER_PROMPT_LAYER_ID",
|
||||
"AgentBackendError",
|
||||
@ -70,11 +66,9 @@ __all__ = [
|
||||
"AgentBackendTransportError",
|
||||
"AgentBackendValidationError",
|
||||
"AgentBackendWorkflowNodeRunInput",
|
||||
"CleanupLayerSpec",
|
||||
"DifyAgentBackendRunClient",
|
||||
"FakeAgentBackendRunClient",
|
||||
"FakeAgentBackendScenario",
|
||||
"create_agent_backend_run_client",
|
||||
"extract_cleanup_layer_specs",
|
||||
"redact_for_agent_backend_log",
|
||||
]
|
||||
|
||||
@ -20,8 +20,6 @@ from dify_agent.protocol import (
|
||||
RunEvent,
|
||||
RunFailedEvent,
|
||||
RunFailedEventData,
|
||||
RunPausedEvent,
|
||||
RunPausedEventData,
|
||||
RunStartedEvent,
|
||||
RunStatusResponse,
|
||||
RunSucceededEvent,
|
||||
@ -36,7 +34,6 @@ class FakeAgentBackendScenario(StrEnum):
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class FakeAgentBackendRunClient:
|
||||
@ -92,13 +89,6 @@ class FakeAgentBackendRunClient:
|
||||
updated_at=_FIXED_TIME,
|
||||
error="fake failure",
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return RunStatusResponse(
|
||||
run_id=run_id,
|
||||
status="paused",
|
||||
created_at=_FIXED_TIME,
|
||||
updated_at=_FIXED_TIME,
|
||||
)
|
||||
|
||||
def _events(self, run_id: str) -> tuple[RunEvent, ...]:
|
||||
match self.scenario:
|
||||
@ -125,17 +115,3 @@ class FakeAgentBackendRunClient:
|
||||
data=RunFailedEventData(error="fake failure", reason="unit_test"),
|
||||
),
|
||||
)
|
||||
case FakeAgentBackendScenario.PAUSED:
|
||||
return (
|
||||
RunStartedEvent(id="1-0", run_id=run_id, created_at=_FIXED_TIME),
|
||||
RunPausedEvent(
|
||||
id="2-0",
|
||||
run_id=run_id,
|
||||
created_at=_FIXED_TIME,
|
||||
data=RunPausedEventData(
|
||||
reason="human_input_required",
|
||||
message="Agent requested human input.",
|
||||
session_snapshot=CompositorSessionSnapshot(layers=[]),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@ -4,37 +4,29 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
|
||||
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
|
||||
protocol has a single owner. API-only context such as Agent Soul vs workflow job
|
||||
prompt is preserved in layer names and metadata until the dedicated product
|
||||
schemas land in later phases. Dify-owned execution identifiers are emitted as an
|
||||
explicit ``dify.execution_context`` layer so the run request stays fully
|
||||
composition-driven.
|
||||
schemas land in later phases.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, cast
|
||||
from typing import ClassVar
|
||||
|
||||
from agenton.compositor import CompositorSessionSnapshot
|
||||
from agenton.compositor.schemas import LayerSessionSnapshot
|
||||
from agenton.layers import ExitIntent
|
||||
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
|
||||
from agenton_collections.layers.pydantic_ai import PYDANTIC_AI_HISTORY_LAYER_TYPE_ID
|
||||
from dify_agent.layers.dify_plugin import (
|
||||
DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
DifyPluginCredentialValue,
|
||||
DifyPluginLayerConfig,
|
||||
DifyPluginLLMLayerConfig,
|
||||
DifyPluginToolsLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.execution_context import (
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
DifyExecutionContextLayerConfig,
|
||||
)
|
||||
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
|
||||
from dify_agent.protocol import (
|
||||
DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
DIFY_AGENT_MODEL_LAYER_ID,
|
||||
DIFY_AGENT_OUTPUT_LAYER_ID,
|
||||
CreateRunRequest,
|
||||
ExecutionContext,
|
||||
LayerExitSignals,
|
||||
RunComposition,
|
||||
RunLayerSpec,
|
||||
@ -45,94 +37,17 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
|
||||
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
|
||||
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
|
||||
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
|
||||
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
|
||||
DIFY_PLUGIN_TOOLS_LAYER_ID = "tools"
|
||||
|
||||
# Layer types that hold credentials in their per-run config. These are excluded
|
||||
# from the cleanup-replay composition (and from the snapshot that is sent with
|
||||
# the cleanup request) because we deliberately do not persist plaintext
|
||||
# credentials between runs.
|
||||
_CLEANUP_EXCLUDED_LAYER_TYPES: tuple[str, ...] = (
|
||||
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
)
|
||||
|
||||
|
||||
class CleanupLayerSpec(BaseModel):
|
||||
"""One layer node replayed by an Agent backend cleanup-only run.
|
||||
|
||||
Cleanup composition cannot include credential-bearing plugin layers, so we
|
||||
persist only the non-plugin layer specs together with the original config.
|
||||
Storing the config (rather than just ``name``/``type``) means cleanup does
|
||||
not depend on the original build-time inputs being re-derivable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
deps: dict[str, str] = Field(default_factory=dict)
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
config: JsonValue = None
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def extract_cleanup_layer_specs(composition: RunComposition) -> list[CleanupLayerSpec]:
|
||||
"""Project the in-flight composition into the persistable cleanup spec list.
|
||||
|
||||
Plugin layers are intentionally dropped (their configs hold credentials and
|
||||
the lifecycle contract says "do not include an LLM layer" during cleanup).
|
||||
The filtered names must later drive snapshot filtering so the agenton
|
||||
compositor's name-order check still passes for the cleanup run.
|
||||
"""
|
||||
excluded = set(_CLEANUP_EXCLUDED_LAYER_TYPES)
|
||||
specs: list[CleanupLayerSpec] = []
|
||||
for layer in composition.layers:
|
||||
if layer.type in excluded:
|
||||
continue
|
||||
config_value: JsonValue = None
|
||||
if isinstance(layer.config, BaseModel):
|
||||
config_value = layer.config.model_dump(mode="json", warnings=False)
|
||||
else:
|
||||
# ``RunLayerSpec.config`` is typed as ``LayerConfigInput`` which
|
||||
# includes ``Mapping[str, object] | bytes``. In the cleanup-replay
|
||||
# pipeline our builder only emits BaseModel-derived configs or
|
||||
# ``None``, so the wider input alias narrows safely here.
|
||||
config_value = cast(JsonValue, layer.config)
|
||||
specs.append(
|
||||
CleanupLayerSpec(
|
||||
name=layer.name,
|
||||
type=layer.type,
|
||||
deps=dict(layer.deps),
|
||||
metadata=dict(layer.metadata),
|
||||
config=config_value,
|
||||
)
|
||||
)
|
||||
return specs
|
||||
|
||||
|
||||
def _filter_snapshot_to_specs(
|
||||
snapshot: CompositorSessionSnapshot,
|
||||
specs: list[CleanupLayerSpec],
|
||||
) -> CompositorSessionSnapshot:
|
||||
"""Keep only snapshot layers whose names appear in the cleanup spec list.
|
||||
|
||||
The agenton compositor rejects a snapshot whose layer-name sequence does
|
||||
not match the active composition exactly. Cleanup-replay drops plugin
|
||||
layers, so we must drop the matching snapshot entries here.
|
||||
"""
|
||||
kept_names = {spec.name for spec in specs}
|
||||
filtered_layers: list[LayerSessionSnapshot] = [layer for layer in snapshot.layers if layer.name in kept_names]
|
||||
if len(filtered_layers) == len(snapshot.layers):
|
||||
return snapshot
|
||||
return CompositorSessionSnapshot(schema_version=snapshot.schema_version, layers=filtered_layers)
|
||||
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
|
||||
|
||||
|
||||
class AgentBackendModelConfig(BaseModel):
|
||||
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
|
||||
|
||||
tenant_id: str
|
||||
plugin_id: str
|
||||
model_provider: str
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
@ -140,14 +55,10 @@ class AgentBackendModelConfig(BaseModel):
|
||||
|
||||
|
||||
class AgentBackendOutputConfig(BaseModel):
|
||||
"""API-side structured output declaration for the conventional output layer.
|
||||
|
||||
The structured-output tool name is fixed to ``final_output`` inside
|
||||
``dify_agent.layers.output`` so callers only control the JSON Schema plus
|
||||
optional description/strictness metadata.
|
||||
"""
|
||||
"""API-side structured output declaration for the conventional output layer."""
|
||||
|
||||
json_schema: dict[str, JsonValue]
|
||||
name: str = "final_result"
|
||||
description: str | None = None
|
||||
strict: bool | None = None
|
||||
|
||||
@ -158,17 +69,15 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
|
||||
|
||||
model: AgentBackendModelConfig
|
||||
execution_context: DifyExecutionContextLayerConfig
|
||||
execution_context: ExecutionContext
|
||||
workflow_node_job_prompt: str
|
||||
user_prompt: str
|
||||
agent_soul_prompt: str | None = None
|
||||
purpose: RunPurpose = "workflow_node"
|
||||
idempotency_key: str | None = None
|
||||
output: AgentBackendOutputConfig | None = None
|
||||
tools: DifyPluginToolsLayerConfig | None = None
|
||||
session_snapshot: CompositorSessionSnapshot | None = None
|
||||
include_history: bool = True
|
||||
suspend_on_exit: bool = True
|
||||
suspend_on_exit: bool = False
|
||||
metadata: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
|
||||
@ -184,50 +93,6 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
|
||||
class AgentBackendRunRequestBuilder:
|
||||
"""Converts API product state into the public ``dify-agent`` run protocol."""
|
||||
|
||||
def build_cleanup_request(
|
||||
self,
|
||||
*,
|
||||
session_snapshot: CompositorSessionSnapshot,
|
||||
composition_layer_specs: list[CleanupLayerSpec],
|
||||
idempotency_key: str | None = None,
|
||||
metadata: dict[str, JsonValue] | None = None,
|
||||
) -> CreateRunRequest:
|
||||
"""Build a lifecycle-only cleanup request that replays the prior layers.
|
||||
|
||||
The agenton compositor enforces that the session snapshot's layer names
|
||||
match the active composition in order, so cleanup must replay the same
|
||||
non-plugin layer graph that produced the snapshot. Plugin layers
|
||||
(``dify.plugin.llm``, ``dify.plugin.tools``) are excluded from both the
|
||||
composition and the snapshot before submission because their configs
|
||||
require credentials that are not persisted between runs.
|
||||
"""
|
||||
if not composition_layer_specs:
|
||||
raise ValueError(
|
||||
"build_cleanup_request requires composition_layer_specs; an empty "
|
||||
"composition would fail the agent backend's snapshot validation."
|
||||
)
|
||||
request_metadata = dict(metadata or {})
|
||||
request_metadata["agent_backend_lifecycle"] = "session_cleanup"
|
||||
layers = [
|
||||
RunLayerSpec(
|
||||
name=spec.name,
|
||||
type=spec.type,
|
||||
deps=dict(spec.deps),
|
||||
metadata=dict(spec.metadata),
|
||||
config=spec.config,
|
||||
)
|
||||
for spec in composition_layer_specs
|
||||
]
|
||||
filtered_snapshot = _filter_snapshot_to_specs(session_snapshot, composition_layer_specs)
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
purpose="workflow_node",
|
||||
idempotency_key=idempotency_key,
|
||||
metadata=request_metadata,
|
||||
session_snapshot=filtered_snapshot,
|
||||
on_exit=LayerExitSignals(default=ExitIntent.DELETE),
|
||||
)
|
||||
|
||||
def build_for_workflow_node(self, run_input: AgentBackendWorkflowNodeRunInput) -> CreateRunRequest:
|
||||
"""Build a workflow Agent Node run request without defining another wire schema."""
|
||||
layers: list[RunLayerSpec] = []
|
||||
@ -256,32 +121,21 @@ class AgentBackendRunRequestBuilder:
|
||||
config=PromptLayerConfig(user=run_input.user_prompt),
|
||||
),
|
||||
RunLayerSpec(
|
||||
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
|
||||
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
|
||||
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LAYER_TYPE_ID,
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.execution_context,
|
||||
config=DifyPluginLayerConfig(
|
||||
tenant_id=run_input.model.tenant_id,
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
user_id=run_input.model.user_id,
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.include_history:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_HISTORY_LAYER_ID,
|
||||
type=PYDANTIC_AI_HISTORY_LAYER_TYPE_ID,
|
||||
metadata={**run_input.metadata, "origin": "agent_session_history"},
|
||||
)
|
||||
)
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
RunLayerSpec(
|
||||
name=DIFY_AGENT_MODEL_LAYER_ID,
|
||||
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=DifyPluginLLMLayerConfig(
|
||||
plugin_id=run_input.model.plugin_id,
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
@ -291,17 +145,6 @@ class AgentBackendRunRequestBuilder:
|
||||
]
|
||||
)
|
||||
|
||||
if run_input.tools is not None and run_input.tools.tools:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
name=DIFY_PLUGIN_TOOLS_LAYER_ID,
|
||||
type=DIFY_PLUGIN_TOOLS_LAYER_TYPE_ID,
|
||||
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
|
||||
metadata=run_input.metadata,
|
||||
config=run_input.tools,
|
||||
)
|
||||
)
|
||||
|
||||
if run_input.output is not None:
|
||||
layers.append(
|
||||
RunLayerSpec(
|
||||
@ -310,6 +153,7 @@ class AgentBackendRunRequestBuilder:
|
||||
metadata=run_input.metadata,
|
||||
config=DifyOutputLayerConfig(
|
||||
json_schema=run_input.output.json_schema,
|
||||
name=run_input.output.name,
|
||||
description=run_input.output.description,
|
||||
strict=run_input.output.strict,
|
||||
),
|
||||
@ -318,6 +162,7 @@ class AgentBackendRunRequestBuilder:
|
||||
|
||||
return CreateRunRequest(
|
||||
composition=RunComposition(layers=layers),
|
||||
execution_context=run_input.execution_context,
|
||||
purpose=run_input.purpose,
|
||||
idempotency_key=run_input.idempotency_key,
|
||||
metadata=run_input.metadata,
|
||||
|
||||
@ -3,13 +3,6 @@ CLI command modules extracted from `commands.py`.
|
||||
"""
|
||||
|
||||
from .account import create_tenant, reset_email, reset_password
|
||||
from .data_migrate import data_migrate, legacy_model_types
|
||||
from .data_migration import (
|
||||
export_migration_data,
|
||||
export_migration_data_template,
|
||||
import_migration_data,
|
||||
migration_data_wizard,
|
||||
)
|
||||
from .plugin import (
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
@ -32,12 +25,7 @@ from .retention import (
|
||||
restore_workflow_runs,
|
||||
)
|
||||
from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage
|
||||
from .system import (
|
||||
convert_to_agent_apps,
|
||||
fix_app_site_missing,
|
||||
reset_encrypt_key_pair,
|
||||
upgrade_db,
|
||||
)
|
||||
from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db
|
||||
from .vector import (
|
||||
add_qdrant_index,
|
||||
migrate_annotation_vector_database,
|
||||
@ -56,24 +44,18 @@ __all__ = [
|
||||
"clear_orphaned_file_records",
|
||||
"convert_to_agent_apps",
|
||||
"create_tenant",
|
||||
"data_migrate",
|
||||
"delete_archived_workflow_runs",
|
||||
"export_app_messages",
|
||||
"export_migration_data",
|
||||
"export_migration_data_template",
|
||||
"extract_plugins",
|
||||
"extract_unique_plugins",
|
||||
"file_usage",
|
||||
"fix_app_site_missing",
|
||||
"import_migration_data",
|
||||
"install_plugins",
|
||||
"install_rag_pipeline_plugins",
|
||||
"legacy_model_types",
|
||||
"migrate_annotation_vector_database",
|
||||
"migrate_data_for_plugin",
|
||||
"migrate_knowledge_vector_database",
|
||||
"migrate_oss",
|
||||
"migration_data_wizard",
|
||||
"old_metadata_migration",
|
||||
"remove_orphaned_files_on_storage",
|
||||
"reset_email",
|
||||
|
||||
@ -1,179 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from services.legacy_model_type_migration import (
|
||||
VALID_TABLE_NAMES,
|
||||
LegacyModelTypeMigrationService,
|
||||
load_tenant_ids_from_file,
|
||||
)
|
||||
|
||||
_SUPPORTED_MODEL_TYPE_CHOICES = (
|
||||
ModelType.LLM.value,
|
||||
ModelType.TEXT_EMBEDDING.value,
|
||||
ModelType.RERANK.value,
|
||||
)
|
||||
_DEFAULT_CONCURRENCY = os.cpu_count() or 1
|
||||
|
||||
|
||||
def _normalize_multi_value_option(
|
||||
values: tuple[str, ...],
|
||||
*,
|
||||
valid_values: tuple[str, ...],
|
||||
option_name: str,
|
||||
) -> tuple[str, ...]:
|
||||
normalized_values: list[str] = []
|
||||
seen_values: set[str] = set()
|
||||
|
||||
for value in values:
|
||||
for item in value.split(","):
|
||||
normalized_item = item.strip()
|
||||
if not normalized_item:
|
||||
continue
|
||||
if normalized_item not in valid_values:
|
||||
raise click.BadParameter(
|
||||
f"invalid value '{normalized_item}'. valid values: {', '.join(valid_values)}",
|
||||
param_hint=option_name,
|
||||
)
|
||||
if normalized_item in seen_values:
|
||||
continue
|
||||
seen_values.add(normalized_item)
|
||||
normalized_values.append(normalized_item)
|
||||
|
||||
return tuple(normalized_values)
|
||||
|
||||
|
||||
@click.group(
|
||||
"data-migrate",
|
||||
help="Online data migration commands.",
|
||||
)
|
||||
def data_migrate() -> None:
|
||||
"""Namespace for production data migration commands."""
|
||||
|
||||
|
||||
@click.command(
|
||||
"legacy-model-types",
|
||||
help=(
|
||||
"Migrate legacy provider model_type values to canonical values. "
|
||||
"Default is dry-run and emits JSON lines only. "
|
||||
"If --tables includes provider_model_credentials, the command may also update "
|
||||
"provider_models and load_balancing_model_configs references so merged credentials stay reachable."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--apply",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Apply the migration. Default is dry-run.",
|
||||
)
|
||||
@click.option(
|
||||
"--tables",
|
||||
"tables",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Limit migration to specific tables. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: load_balancing_model_configs, provider_model_credentials, "
|
||||
"provider_model_settings, provider_models, tenant_default_models.\n\n"
|
||||
"When provider_model_credentials is selected, provider_models and "
|
||||
"load_balancing_model_configs may also be updated for credential reference rewrites.\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant tables are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--model-types",
|
||||
"model_types",
|
||||
multiple=True,
|
||||
type=str,
|
||||
help=(
|
||||
"Canonical model types to migrate. Accepts comma-separated values or repeated flags.\n"
|
||||
"\n"
|
||||
"Options: llm,text-embedding,rerank\n"
|
||||
"\n"
|
||||
"If unspecified, all relevant legacy model types are migrated."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--tenant-id-file",
|
||||
type=click.Path(exists=True, dir_okay=False, readable=True, resolve_path=True),
|
||||
help="Optional file containing tenant ids, one per line.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
type=click.Path(dir_okay=False, resolve_path=True, path_type=Path),
|
||||
help=(
|
||||
"Optional file path for JSON lines event logs. Defaults to stdout.\n"
|
||||
"It's highly recommended to save the event logs to a file and preserve it for a period of time."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--concurrency",
|
||||
type=click.IntRange(min=1),
|
||||
default=_DEFAULT_CONCURRENCY,
|
||||
show_default=True,
|
||||
help="Number of tenant-level worker threads to run in parallel.",
|
||||
)
|
||||
def legacy_model_types(
|
||||
apply: bool,
|
||||
tables: tuple[str, ...],
|
||||
model_types: tuple[str, ...],
|
||||
tenant_id_file: str | None,
|
||||
output: Path | None,
|
||||
concurrency: int = _DEFAULT_CONCURRENCY,
|
||||
) -> None:
|
||||
"""
|
||||
Migrate legacy provider-related model_type values and emit JSON lines events.
|
||||
"""
|
||||
|
||||
normalized_tables = _normalize_multi_value_option(
|
||||
tables,
|
||||
valid_values=VALID_TABLE_NAMES,
|
||||
option_name="--tables",
|
||||
)
|
||||
normalized_model_types = _normalize_multi_value_option(
|
||||
model_types,
|
||||
valid_values=_SUPPORTED_MODEL_TYPE_CHOICES,
|
||||
option_name="--model-types",
|
||||
)
|
||||
selected_model_types = (
|
||||
tuple(ModelType.value_of(model_type) for model_type in normalized_model_types)
|
||||
if normalized_model_types
|
||||
else (
|
||||
ModelType.LLM,
|
||||
ModelType.TEXT_EMBEDDING,
|
||||
ModelType.RERANK,
|
||||
)
|
||||
)
|
||||
tenant_ids = load_tenant_ids_from_file(tenant_id_file) if tenant_id_file else None
|
||||
|
||||
output_context: AbstractContextManager[io.TextIOBase]
|
||||
if output is None:
|
||||
output_context = nullcontext(cast(io.TextIOBase, sys.stdout))
|
||||
else:
|
||||
try:
|
||||
output_context = output.open("w", encoding="utf-8")
|
||||
except OSError as exc:
|
||||
raise click.ClickException(f"failed to open output file '{output}': {exc.strerror or exc}") from exc
|
||||
|
||||
with output_context as output_stream:
|
||||
LegacyModelTypeMigrationService(
|
||||
engine=db.engine,
|
||||
apply=apply,
|
||||
concurrency=concurrency,
|
||||
output=cast(io.TextIOBase, output_stream),
|
||||
tables=normalized_tables or None,
|
||||
model_types=selected_model_types,
|
||||
tenant_ids=tenant_ids,
|
||||
).migrate()
|
||||
|
||||
|
||||
data_migrate.add_command(legacy_model_types)
|
||||
@ -1,754 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.model import App
|
||||
from models.tools import ApiToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.data_migration.dependency_discovery_service import DependencyDiscoveryService
|
||||
from services.data_migration.entities import (
|
||||
DependencyKind,
|
||||
ImportOptions,
|
||||
MigrationDataError,
|
||||
ReportContext,
|
||||
ResourceReportItem,
|
||||
)
|
||||
from services.data_migration.export_service import ExportConfigParser, MigrationExportService
|
||||
from services.data_migration.import_service import ImportRequest, MigrationImportService
|
||||
from services.data_migration.package_service import MigrationPackageService
|
||||
from services.data_migration.report_service import MigrationReportService
|
||||
|
||||
ID_STRATEGY_CHOICES = ["preserve-id", "generate-new-id"]
|
||||
CONFLICT_STRATEGY_CHOICES = ["fail", "skip", "update"]
|
||||
SUPPORTED_WIZARD_APP_MODES = ["workflow", "advanced-chat"]
|
||||
WizardToolMap = dict[str, dict[str, str | None]]
|
||||
WizardToolSelection = dict[str, list[str]]
|
||||
|
||||
|
||||
def _scripted_export_template() -> dict[str, Any]:
|
||||
return {
|
||||
"source_tenant": {
|
||||
"mode": "single",
|
||||
"id": "",
|
||||
"name": "admin's Workspace",
|
||||
},
|
||||
"apps": {
|
||||
"modes": ["workflow", "advanced-chat"],
|
||||
"ids": [],
|
||||
"all": True,
|
||||
},
|
||||
"include_referenced_tools": True,
|
||||
"additional_tools": {
|
||||
"api_tools": [],
|
||||
"workflow_tools": [],
|
||||
"mcp_tools": [],
|
||||
},
|
||||
"include_secrets": False,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": False,
|
||||
"id_strategy": "preserve-id",
|
||||
"conflict_strategy": "fail",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@click.command("app-migration-template", help="Print or write a scripted export config JSON template.")
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to write the export config JSON template. Prints to stdout when omitted.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data_template(output_file: str | None, overwrite: bool) -> None:
|
||||
template_json = json.dumps(_scripted_export_template(), indent=2, ensure_ascii=False) + "\n"
|
||||
if output_file is None:
|
||||
click.echo(template_json, nl=False)
|
||||
return
|
||||
path = Path(output_file)
|
||||
if path.exists() and not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
path.write_text(template_json)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-migration", help="Export workflow migration data to a versioned JSON package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to export config JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--output",
|
||||
"output_file",
|
||||
required=False,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--overwrite", is_flag=True, default=False, help="Overwrite output if it already exists.")
|
||||
def export_migration_data(input_file: str | None, output_file: str | None, overwrite: bool) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file), ("--output", output_file))
|
||||
assert input_file is not None
|
||||
assert output_file is not None
|
||||
raw_config = _load_json_object(input_file, "Export config")
|
||||
selection = ExportConfigParser().parse(raw_config)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
@click.command("import-app-migration", help="Import a versioned migration data package.")
|
||||
@click.option(
|
||||
"--input",
|
||||
"input_file",
|
||||
required=False,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Path to migration package JSON.",
|
||||
)
|
||||
@click.option("--target-tenant", default=None, help="Target tenant/workspace name. Overrides package metadata.")
|
||||
@click.option("--operator-email", default=None, help="Operator account email in the target tenant.")
|
||||
@click.option(
|
||||
"--id-strategy",
|
||||
default=None,
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
help="Override package ID strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--conflict-strategy",
|
||||
default=None,
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
help="Override package conflict strategy.",
|
||||
)
|
||||
@click.option(
|
||||
"--create-app-api-token-on-import/--no-create-app-api-token-on-import",
|
||||
default=None,
|
||||
help="Override package app API token creation behavior.",
|
||||
)
|
||||
def import_migration_data(
|
||||
input_file: str | None,
|
||||
target_tenant: str | None,
|
||||
operator_email: str | None,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> None:
|
||||
try:
|
||||
_require_options(("--input", input_file))
|
||||
assert input_file is not None
|
||||
package = MigrationPackageService().load_package(input_file)
|
||||
result = MigrationImportService().import_package(
|
||||
ImportRequest(
|
||||
package=package,
|
||||
cli_target_tenant=target_tenant,
|
||||
operator_email=operator_email,
|
||||
options_override=_build_options_override(
|
||||
package.metadata.import_options,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
create_app_api_token_on_import=create_app_api_token_on_import,
|
||||
),
|
||||
)
|
||||
)
|
||||
_render_report(result.report_items, context=result.report_context)
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def parse_index_selection(raw: str, values: list[str]) -> list[str]:
|
||||
normalized = raw.strip().lower()
|
||||
if normalized == "all":
|
||||
return values
|
||||
|
||||
selected: list[str] = []
|
||||
for part in raw.split(","):
|
||||
stripped = part.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
index = int(stripped)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(f"Selection must be 'all' or comma-separated numbers: {raw}") from exc
|
||||
if index < 1 or index > len(values):
|
||||
raise click.ClickException(f"Selection index out of range: {index}")
|
||||
selected.append(values[index - 1])
|
||||
return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _print_wizard_step(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"==== {title} ====")
|
||||
|
||||
|
||||
def _print_wizard_substep(title: str) -> None:
|
||||
click.echo("")
|
||||
click.echo(f"-- {title} --")
|
||||
|
||||
|
||||
@click.command("app-migration-wizard", help="Interactively export workflow migration data.")
|
||||
def migration_data_wizard() -> None:
|
||||
try:
|
||||
tenant = _prompt_source_tenant()
|
||||
apps = _eligible_apps_for_tenant(tenant.id)
|
||||
app_ids = _prompt_app_ids(apps)
|
||||
_print_wizard_step("Referenced Tools")
|
||||
include_referenced_tools = click.confirm(
|
||||
"Automatically export tools referenced by selected apps? [y/n, default: y]",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
auto_tools = _discover_auto_tools([app for app in apps if app.id in set(app_ids)], include_referenced_tools)
|
||||
auto_tools = _resolve_auto_tool_names(tenant.id, auto_tools)
|
||||
_print_auto_tools(auto_tools)
|
||||
additional_tools = _prompt_additional_tools(tenant.id, auto_tools)
|
||||
include_secrets, create_tokens, id_strategy, conflict_strategy = _prompt_import_options()
|
||||
_print_wizard_step("Output")
|
||||
output_file, overwrite = _prompt_output_file()
|
||||
|
||||
selection = ExportConfigParser().parse(
|
||||
{
|
||||
"source_tenant": {"mode": "single", "id": tenant.id, "name": tenant.name},
|
||||
"apps": {"ids": app_ids, "all": False},
|
||||
"include_referenced_tools": include_referenced_tools,
|
||||
"additional_tools": additional_tools,
|
||||
"include_secrets": include_secrets,
|
||||
"import_options": {
|
||||
"create_app_api_token_on_import": create_tokens,
|
||||
"id_strategy": id_strategy,
|
||||
"conflict_strategy": conflict_strategy,
|
||||
},
|
||||
}
|
||||
)
|
||||
_confirm_wizard_summary(
|
||||
tenant_name=tenant.name,
|
||||
app_names=[app.name for app in apps if app.id in set(app_ids)],
|
||||
auto_tools=auto_tools,
|
||||
additional_tools=additional_tools,
|
||||
manual_labels=_selected_tool_labels_for_tenant(tenant.id, additional_tools),
|
||||
include_referenced_tools=include_referenced_tools,
|
||||
include_secrets=include_secrets,
|
||||
create_tokens=create_tokens,
|
||||
id_strategy=id_strategy,
|
||||
conflict_strategy=conflict_strategy,
|
||||
output_file=output_file,
|
||||
)
|
||||
result = MigrationExportService().export(selection)
|
||||
MigrationPackageService().save_package(result.package, output_file, overwrite=overwrite)
|
||||
click.echo(click.style(f"Output written to {output_file}", fg="green"))
|
||||
_print_wizard_step("Report")
|
||||
_render_report(result.report_items, context=_with_output_path(result.report_context, output_file))
|
||||
except MigrationDataError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
|
||||
|
||||
def _load_json_object(path: str, label: str) -> dict[str, Any]:
|
||||
try:
|
||||
with Path(path).open(encoding="utf-8") as file:
|
||||
raw = json.load(file)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise MigrationDataError(f"{label} JSON is invalid: {exc.msg}") from exc
|
||||
if not isinstance(raw, dict):
|
||||
raise MigrationDataError(f"{label} JSON must be an object.")
|
||||
return raw
|
||||
|
||||
|
||||
def _require_options(*options: tuple[str, object | None]) -> None:
|
||||
missing_options = [name for name, value in options if value is None]
|
||||
if missing_options:
|
||||
raise click.UsageError(f"Missing option(s): {', '.join(missing_options)}.")
|
||||
|
||||
|
||||
def _build_options_override(
|
||||
package_options: ImportOptions,
|
||||
*,
|
||||
id_strategy: str | None,
|
||||
conflict_strategy: str | None,
|
||||
create_app_api_token_on_import: bool | None,
|
||||
) -> ImportOptions | None:
|
||||
if id_strategy is None and conflict_strategy is None and create_app_api_token_on_import is None:
|
||||
return None
|
||||
return ImportOptions.from_mapping(
|
||||
{
|
||||
"id_strategy": id_strategy or package_options.id_strategy,
|
||||
"conflict_strategy": conflict_strategy or package_options.conflict_strategy,
|
||||
"create_app_api_token_on_import": (
|
||||
create_app_api_token_on_import
|
||||
if create_app_api_token_on_import is not None
|
||||
else package_options.create_app_api_token_on_import
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _prompt_source_tenant() -> Tenant:
|
||||
tenants = list(db.session.scalars(sa.select(Tenant).order_by(Tenant.name.asc())).all())
|
||||
if not tenants:
|
||||
raise MigrationDataError("No tenants found.")
|
||||
|
||||
_print_wizard_step("Source Tenant")
|
||||
click.echo("Source tenants:")
|
||||
for index, tenant in enumerate(tenants, 1):
|
||||
click.echo(f"{index}. {tenant.name} ({tenant.id})")
|
||||
|
||||
tenant_index = click.prompt("Select one source tenant by number", type=int, default=1, show_default=True)
|
||||
if tenant_index < 1 or tenant_index > len(tenants):
|
||||
raise click.ClickException(f"Selection index out of range: {tenant_index}")
|
||||
return tenants[tenant_index - 1]
|
||||
|
||||
|
||||
def _eligible_apps_for_tenant(tenant_id: str) -> list[App]:
|
||||
return list(
|
||||
db.session.scalars(
|
||||
sa.select(App)
|
||||
.where(App.tenant_id == tenant_id, App.mode.in_(SUPPORTED_WIZARD_APP_MODES))
|
||||
.order_by(App.name.asc())
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def _prompt_app_ids(apps: list[App]) -> list[str]:
|
||||
if not apps:
|
||||
raise MigrationDataError("No workflow or advanced-chat apps found for the selected tenant.")
|
||||
|
||||
_print_wizard_step("App Selection")
|
||||
click.echo("Currently supported app types: workflow and chatflow.")
|
||||
click.echo("Workflow/chatflow apps:")
|
||||
for index, app in enumerate(apps, 1):
|
||||
mode = app.mode.value if hasattr(app.mode, "value") else app.mode
|
||||
click.echo(f"{index}. {app.name} [{mode}] ({app.id})")
|
||||
app_ids = parse_index_selection(
|
||||
click.prompt("Select apps by number, comma-separated numbers, or all", default="all"),
|
||||
[app.id for app in apps],
|
||||
)
|
||||
selected_apps = [app for app in apps if app.id in set(app_ids)]
|
||||
click.echo("Selected apps:")
|
||||
for app in selected_apps:
|
||||
click.echo(f"- {app.name} ({app.id})")
|
||||
return app_ids
|
||||
|
||||
|
||||
def _prompt_import_options() -> tuple[bool, bool, str, str]:
|
||||
_print_wizard_step("Import Options")
|
||||
_print_wizard_substep("Secrets")
|
||||
click.echo("Secrets include workflow/app DSL secret values, custom API tool credentials,")
|
||||
click.echo("and full MCP provider connection data such as server URL, headers, authentication, and tool list.")
|
||||
click.echo("If you choose no, credentials are omitted or masked,")
|
||||
click.echo("and MCP providers are exported as dependency metadata only.")
|
||||
click.echo("Treat the output JSON as sensitive if you choose yes.")
|
||||
include_secrets = click.confirm(
|
||||
"Include secrets in output JSON? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("App API Tokens")
|
||||
click.echo("When enabled, import will create an app API token if the imported app has none,")
|
||||
click.echo("or reuse an existing app API token if one already exists.")
|
||||
create_tokens = click.confirm(
|
||||
"Create or reuse app API tokens during import? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
_print_wizard_substep("ID Strategy")
|
||||
click.echo("ID strategy controls whether imported app and tool IDs preserve source IDs")
|
||||
click.echo("or use target-generated IDs.")
|
||||
click.echo("preserve-id: keep source IDs where the target service supports it.")
|
||||
click.echo("generate-new-id: let the target environment generate new IDs and rewrite references via mapping.")
|
||||
id_strategy = click.prompt(
|
||||
"Import ID strategy. Enter one of: preserve-id, generate-new-id",
|
||||
type=click.Choice(ID_STRATEGY_CHOICES),
|
||||
default="preserve-id",
|
||||
show_default=True,
|
||||
)
|
||||
_print_wizard_substep("Conflict Strategy")
|
||||
click.echo("Conflict strategy controls what import does when a target resource already exists.")
|
||||
click.echo("fail: stop at the first conflict; previously committed resources are not rolled back.")
|
||||
click.echo("skip: keep the existing target resource and skip importing that resource.")
|
||||
click.echo("update: update the existing target resource in place.")
|
||||
conflict_strategy = click.prompt(
|
||||
"Import conflict strategy. Enter one of: fail, skip, update",
|
||||
type=click.Choice(CONFLICT_STRATEGY_CHOICES),
|
||||
default="update",
|
||||
show_default=True,
|
||||
)
|
||||
return include_secrets, create_tokens, id_strategy, conflict_strategy
|
||||
|
||||
|
||||
def _discover_auto_tools(apps: list[App], include_referenced_tools: bool) -> WizardToolMap:
|
||||
auto_tools: WizardToolMap = {"api_tools": {}, "workflow_tools": {}, "mcp_tools": {}}
|
||||
if not include_referenced_tools:
|
||||
return auto_tools
|
||||
discovery_service = DependencyDiscoveryService()
|
||||
for app in apps:
|
||||
dsl_content = AppDslService.export_dsl(app_model=app, include_secret=False)
|
||||
raw_dsl = yaml.safe_load(dsl_content) if dsl_content else {}
|
||||
dsl = raw_dsl if isinstance(raw_dsl, dict) else {}
|
||||
for dependency in discovery_service.discover_from_dsl(dsl):
|
||||
if dependency.kind == DependencyKind.API_TOOL:
|
||||
auto_tools["api_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
elif dependency.kind == DependencyKind.WORKFLOW_TOOL:
|
||||
auto_tools["workflow_tools"][dependency.provider_name or dependency.provider_id] = (
|
||||
dependency.provider_id
|
||||
)
|
||||
elif dependency.kind == DependencyKind.MCP_TOOL:
|
||||
auto_tools["mcp_tools"][dependency.provider_name or dependency.provider_id] = dependency.provider_id
|
||||
return auto_tools
|
||||
|
||||
|
||||
def _resolve_auto_tool_names(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolMap:
|
||||
return {
|
||||
"api_tools": _resolve_api_tool_names(tenant_id, auto_tools["api_tools"]),
|
||||
"workflow_tools": _resolve_workflow_tool_names(tenant_id, auto_tools["workflow_tools"]),
|
||||
"mcp_tools": _resolve_mcp_tool_names(tenant_id, auto_tools["mcp_tools"]),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_api_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [ApiToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(ApiToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(ApiToolProvider).where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_workflow_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [WorkflowToolProvider.name == name]
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(WorkflowToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_mcp_tool_names(tenant_id: str, tools: dict[str, str | None]) -> dict[str, str | None]:
|
||||
resolved: dict[str, str | None] = {}
|
||||
for name, identifier in tools.items():
|
||||
predicates = [MCPToolProvider.name == name]
|
||||
if identifier:
|
||||
predicates.append(MCPToolProvider.server_identifier == identifier)
|
||||
if _is_uuid_string(identifier):
|
||||
predicates.append(MCPToolProvider.id == identifier)
|
||||
provider = db.session.scalar(
|
||||
sa.select(MCPToolProvider).where(
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
sa.or_(*predicates),
|
||||
)
|
||||
)
|
||||
resolved[provider.name if provider else name] = provider.id if provider else identifier
|
||||
return resolved
|
||||
|
||||
|
||||
def _is_uuid_string(value: str | None) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
try:
|
||||
UUID(value)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_auto_tools(auto_tools: WizardToolMap) -> None:
|
||||
_print_wizard_step("Automatically Discovered Tools")
|
||||
click.echo("Automatically discovered tools:")
|
||||
_print_auto_tool_category("Custom API tools", auto_tools["api_tools"])
|
||||
_print_auto_tool_category("Workflow tools", auto_tools["workflow_tools"])
|
||||
_print_auto_tool_category("MCP tools", auto_tools["mcp_tools"])
|
||||
|
||||
|
||||
def _print_auto_tool_category(label: str, values: dict[str, str | None]) -> None:
|
||||
click.echo(label)
|
||||
if not values:
|
||||
click.echo("- none")
|
||||
return
|
||||
for name, identifier in sorted(values.items()):
|
||||
click.echo(f"- {_format_tool_name_id(name, identifier)}")
|
||||
|
||||
|
||||
def _prompt_additional_tools(tenant_id: str, auto_tools: WizardToolMap) -> WizardToolSelection:
|
||||
selections: WizardToolSelection = {"api_tools": [], "workflow_tools": [], "mcp_tools": []}
|
||||
_print_wizard_step("Additional Tools")
|
||||
if not click.confirm(
|
||||
"Export additional tools manually? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
):
|
||||
_print_final_tool_selection(auto_tools, selections, {})
|
||||
return selections
|
||||
manual_labels: dict[str, str] = {}
|
||||
api_tool_options = [
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["api_tools"] = _prompt_tool_category(
|
||||
"Custom API tools",
|
||||
api_tool_options,
|
||||
auto_tools=auto_tools["api_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(api_tool_options, selections["api_tools"]))
|
||||
workflow_tool_options = [
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["workflow_tools"] = _prompt_tool_category(
|
||||
"Workflow tools",
|
||||
workflow_tool_options,
|
||||
auto_tools=auto_tools["workflow_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(workflow_tool_options, selections["workflow_tools"]))
|
||||
mcp_tool_options = [
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider).where(MCPToolProvider.tenant_id == tenant_id).order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
]
|
||||
selections["mcp_tools"] = _prompt_tool_category(
|
||||
"MCP tools",
|
||||
mcp_tool_options,
|
||||
auto_tools=auto_tools["mcp_tools"],
|
||||
)
|
||||
manual_labels.update(_selected_tool_labels(mcp_tool_options, selections["mcp_tools"]))
|
||||
_print_final_tool_selection(auto_tools, selections, manual_labels)
|
||||
return selections
|
||||
|
||||
|
||||
def _selected_tool_labels_for_tenant(tenant_id: str, selected_tools: WizardToolSelection) -> dict[str, str]:
|
||||
labels: dict[str, str] = {}
|
||||
if selected_tools["api_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.name, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id)
|
||||
.order_by(ApiToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["api_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["workflow_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.id)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
.order_by(WorkflowToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["workflow_tools"],
|
||||
)
|
||||
)
|
||||
if selected_tools["mcp_tools"]:
|
||||
labels.update(
|
||||
_selected_tool_labels(
|
||||
[
|
||||
(tool.id, tool.name, tool.server_identifier)
|
||||
for tool in db.session.scalars(
|
||||
sa.select(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id)
|
||||
.order_by(MCPToolProvider.name)
|
||||
).all()
|
||||
],
|
||||
selected_tools["mcp_tools"],
|
||||
)
|
||||
)
|
||||
return labels
|
||||
|
||||
|
||||
def _selected_tool_labels(options: list[tuple[str, str, str]], selected_values: list[str]) -> dict[str, str]:
|
||||
selected = set(selected_values)
|
||||
return {value: _format_tool_name_id(name, detail) for value, name, detail in options if value in selected}
|
||||
|
||||
|
||||
def _prompt_tool_category(
|
||||
label: str,
|
||||
options: list[tuple[str, str, str]],
|
||||
*,
|
||||
auto_tools: dict[str, str | None],
|
||||
) -> list[str]:
|
||||
if not options:
|
||||
click.echo(f"{label}: none")
|
||||
return []
|
||||
_print_wizard_step(label)
|
||||
for index, (value, name, detail) in enumerate(options, 1):
|
||||
marker = "[auto]" if _is_auto_tool(value, name, detail, auto_tools) else "[ ]"
|
||||
click.echo(f"{index}. {marker} {name} ({detail})")
|
||||
raw = click.prompt(
|
||||
f"Select {label.lower()} by number, comma-separated numbers, all, or empty",
|
||||
default="",
|
||||
show_default=cast(Any, "empty"),
|
||||
)
|
||||
if not raw.strip():
|
||||
return []
|
||||
return parse_index_selection(raw, [value for value, _, _ in options])
|
||||
|
||||
|
||||
def _is_auto_tool(value: str, name: str, detail: str, auto_tools: dict[str, str | None]) -> bool:
|
||||
return name in auto_tools or value in auto_tools or value in auto_tools.values() or detail in auto_tools.values()
|
||||
|
||||
|
||||
def _print_final_tool_selection(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
_print_wizard_step("Final Tool Selection")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
|
||||
|
||||
def _print_tool_selection_body(
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo("Final tools to export:")
|
||||
_print_final_tool_category(
|
||||
"Custom API tools",
|
||||
auto_tools["api_tools"],
|
||||
additional_tools["api_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category(
|
||||
"Workflow tools",
|
||||
auto_tools["workflow_tools"],
|
||||
additional_tools["workflow_tools"],
|
||||
manual_labels,
|
||||
)
|
||||
_print_final_tool_category("MCP tools", auto_tools["mcp_tools"], additional_tools["mcp_tools"], manual_labels)
|
||||
|
||||
|
||||
def _print_final_tool_category(
|
||||
label: str,
|
||||
auto_tools: dict[str, str | None],
|
||||
manual_values: list[str],
|
||||
manual_labels: dict[str, str],
|
||||
) -> None:
|
||||
click.echo(label)
|
||||
lines = [f"- [auto] {_format_tool_name_id(name, identifier)}" for name, identifier in sorted(auto_tools.items())]
|
||||
auto_identifiers = {identifier for identifier in auto_tools.values() if identifier}
|
||||
lines.extend(
|
||||
f"- [manual] {manual_labels.get(value, value)}"
|
||||
for value in manual_values
|
||||
if value not in auto_tools and value not in auto_identifiers
|
||||
)
|
||||
if not lines:
|
||||
click.echo("- none")
|
||||
return
|
||||
for line in lines:
|
||||
click.echo(line)
|
||||
|
||||
|
||||
def _format_tool_name_id(name: str, identifier: str | None) -> str:
|
||||
if identifier and identifier != name:
|
||||
return f"{name}: {identifier}"
|
||||
return name
|
||||
|
||||
|
||||
def _confirm_wizard_summary(
|
||||
*,
|
||||
tenant_name: str,
|
||||
app_names: list[str],
|
||||
auto_tools: WizardToolMap,
|
||||
additional_tools: WizardToolSelection,
|
||||
manual_labels: dict[str, str],
|
||||
include_referenced_tools: bool,
|
||||
include_secrets: bool,
|
||||
create_tokens: bool,
|
||||
id_strategy: str,
|
||||
conflict_strategy: str,
|
||||
output_file: str,
|
||||
) -> None:
|
||||
_print_wizard_step("Summary")
|
||||
click.echo("Migration export summary:")
|
||||
click.echo(f"source tenant: {tenant_name}")
|
||||
click.echo(f"selected apps: {len(app_names)}")
|
||||
for app_name in app_names:
|
||||
click.echo(f"- {app_name}")
|
||||
click.echo(f"auto referenced tools: {str(include_referenced_tools).lower()}")
|
||||
_print_tool_selection_body(auto_tools, additional_tools, manual_labels)
|
||||
click.echo(f"include secrets: {str(include_secrets).lower()}")
|
||||
click.echo(f"create app api token on import: {str(create_tokens).lower()}")
|
||||
click.echo(f"id strategy: {id_strategy}")
|
||||
click.echo(f"conflict strategy: {conflict_strategy}")
|
||||
click.echo(f"output path: {output_file}")
|
||||
if not click.confirm("Write migration package? [y/n, default: y]", default=True, show_default=False):
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _prompt_output_file() -> tuple[str, bool]:
|
||||
default_output = f"migration-data-{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
output_file = click.prompt("Output path", default=default_output, show_default=True)
|
||||
if output_file.lower() in {"y", "yes", "n", "no"}:
|
||||
raise click.ClickException("Output path must be a file path. Press Enter to use the default path.")
|
||||
overwrite = False
|
||||
if Path(output_file).exists():
|
||||
overwrite = click.confirm(
|
||||
"Output file exists. Overwrite? [y/n, default: n]",
|
||||
default=False,
|
||||
show_default=False,
|
||||
)
|
||||
if not overwrite:
|
||||
raise click.ClickException(f"Output file already exists: {output_file}")
|
||||
return output_file, overwrite
|
||||
|
||||
|
||||
def _with_output_path(context: ReportContext | None, output_path: str) -> ReportContext:
|
||||
if context is None:
|
||||
return ReportContext(output_path=output_path)
|
||||
return ReportContext(
|
||||
output_path=output_path,
|
||||
source_scope=context.source_scope,
|
||||
selected_app_count=context.selected_app_count,
|
||||
include_secrets=context.include_secrets,
|
||||
target_tenant=context.target_tenant,
|
||||
operator_email=context.operator_email,
|
||||
app_api_tokens_created=context.app_api_tokens_created,
|
||||
app_api_tokens_reused=context.app_api_tokens_reused,
|
||||
id_mapping_count=context.id_mapping_count,
|
||||
id_mappings=context.id_mappings,
|
||||
)
|
||||
|
||||
|
||||
def _render_report(report_items: list[ResourceReportItem], *, context: ReportContext | None = None) -> None:
|
||||
for line in MigrationReportService().render(report_items, context=context):
|
||||
click.echo(line)
|
||||
@ -30,7 +30,7 @@ def vdb_migrate(scope: str):
|
||||
|
||||
def migrate_annotation_vector_database():
|
||||
"""
|
||||
Migrate annotation data to target vector database.
|
||||
Migrate annotation datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Starting annotation data migration.", fg="green"))
|
||||
create_count = 0
|
||||
@ -140,7 +140,7 @@ def migrate_annotation_vector_database():
|
||||
|
||||
def migrate_knowledge_vector_database():
|
||||
"""
|
||||
Migrate vector database data to target vector database.
|
||||
Migrate vector database datas to target vector database .
|
||||
"""
|
||||
click.echo(click.style("Starting vector database migration.", fg="green"))
|
||||
create_count = 0
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -25,7 +23,7 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
|
||||
EDITION: str = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
)
|
||||
|
||||
@ -525,44 +525,6 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
OPENAPI_ENABLED: bool = Field(
|
||||
description=(
|
||||
"Enable the /openapi/v1/* endpoint group used by difyctl and other "
|
||||
"programmatic clients. Set to true to activate; disabled by default."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_ENABLED"),
|
||||
default=False,
|
||||
)
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -938,17 +900,6 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1235,14 +1186,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
@ -41,21 +41,3 @@ class MilvusConfig(BaseSettings):
|
||||
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="Enable TLS for the Milvus connection (one-way TLS). When True, the client uses gRPC over TLS "
|
||||
"and verifies the server certificate. Equivalent to passing secure=True to pymilvus.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_PEM_PATH: str | None = Field(
|
||||
description="Filesystem path inside the container to the Milvus server certificate (PEM). Mount this via "
|
||||
"a Kubernetes secret. Used as pymilvus's server_pem_path when MILVUS_SECURE is True.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SERVER_NAME: str | None = Field(
|
||||
description="Server name (TLS SNI / certificate CN or SAN) to verify against the Milvus server certificate. "
|
||||
"Required when MILVUS_SERVER_PEM_PATH is set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@ -68,7 +68,6 @@ from .app import (
|
||||
workflow_app_log,
|
||||
workflow_comment,
|
||||
workflow_draft_variable,
|
||||
workflow_node_output_inspector,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
@ -219,7 +218,6 @@ __all__ = [
|
||||
"workflow_app_log",
|
||||
"workflow_comment",
|
||||
"workflow_draft_variable",
|
||||
"workflow_node_output_inspector",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
|
||||
@ -5,7 +5,7 @@ from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from services.agent.composer_service import AgentComposerService
|
||||
from services.agent.composer_validator import ComposerConfigValidator
|
||||
from services.entities.agent_entities import ComposerSavePayload
|
||||
@ -19,7 +19,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
def get(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_workflow_composer(
|
||||
tenant_id=tenant_id,
|
||||
@ -33,7 +33,7 @@ class WorkflowAgentComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
def put(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
@ -52,7 +52,7 @@ class WorkflowAgentComposerValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, app_model, node_id: str):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
@ -64,7 +64,7 @@ class WorkflowAgentComposerCandidatesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
def get(self, app_model, node_id: str):
|
||||
return AgentComposerService.get_workflow_candidates(app_id=app_model.id)
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ class WorkflowAgentComposerImpactApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, app_model, node_id: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
current_snapshot_id = payload.binding.current_snapshot_id if payload.binding else None
|
||||
@ -91,7 +91,7 @@ class WorkflowAgentComposerSaveToRosterApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, app_model, node_id: str):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_workflow_composer(
|
||||
@ -109,7 +109,7 @@ class AgentAppComposerApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return AgentComposerService.load_agent_app_composer(tenant_id=tenant_id, app_id=app_model.id)
|
||||
|
||||
@ -119,7 +119,7 @@ class AgentAppComposerApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model()
|
||||
def put(self, app_model: App):
|
||||
def put(self, app_model):
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
return AgentComposerService.save_agent_app_composer(
|
||||
@ -137,7 +137,7 @@ class AgentAppComposerValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
payload = ComposerSavePayload.model_validate(console_ns.payload or {})
|
||||
ComposerConfigValidator.validate_save_payload(payload)
|
||||
return {"result": "success", "errors": []}
|
||||
@ -149,5 +149,5 @@ class AgentAppComposerCandidatesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
return AgentComposerService.get_agent_app_candidates(app_id=app_model.id)
|
||||
|
||||
@ -9,25 +9,18 @@ from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.common.schema import register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.enums import ApiTokenType
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyItem(ResponseModel):
|
||||
@ -47,7 +40,7 @@ class ApiKeyList(ResponseModel):
|
||||
data: list[ApiKeyItem]
|
||||
|
||||
|
||||
register_response_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
@ -71,11 +64,10 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
def get(self, resource_id: str, current_tenant_id: str) -> dict[str, object]:
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(resource_id, current_tenant_id))
|
||||
|
||||
def _get_api_key_list(self, resource_id: str, current_tenant_id: str) -> ApiKeyList:
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
@ -83,14 +75,13 @@ class BaseApiKeyListResource(Resource):
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
)
|
||||
).all()
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True)
|
||||
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
@edit_permission_required
|
||||
def post(self, resource_id: str, current_tenant_id: str) -> tuple[dict[str, object], int]:
|
||||
return dump_response(ApiKeyItem, self._create_api_key(resource_id, current_tenant_id)), 201
|
||||
|
||||
def _create_api_key(self, resource_id: str, current_tenant_id: str) -> ApiToken:
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
current_key_count: int = (
|
||||
db.session.scalar(
|
||||
@ -117,7 +108,7 @@ class BaseApiKeyListResource(Resource):
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
db.session.commit()
|
||||
return api_token
|
||||
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
class BaseApiKeyResource(Resource):
|
||||
@ -127,20 +118,9 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(
|
||||
self, resource_id: str, api_key_id: str, current_tenant_id: str, current_user: Account
|
||||
) -> tuple[str, int]:
|
||||
self._delete_api_key(resource_id, api_key_id, current_tenant_id, current_user)
|
||||
return "", 204
|
||||
|
||||
def _delete_api_key(
|
||||
self,
|
||||
resource_id: str,
|
||||
api_key_id: str,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> None:
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
@ -167,6 +147,8 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@ -174,21 +156,18 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
def get(self, resource_id: UUID):
|
||||
"""Get all API keys for an app"""
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
def post(self, resource_id: UUID):
|
||||
"""Create a new API key for an app"""
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -202,14 +181,9 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
"""Delete an API key for an app"""
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
|
||||
resource_type = ApiTokenType.APP
|
||||
resource_model = App
|
||||
@ -222,21 +196,18 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, resource_id: UUID) -> dict[str, object]:
|
||||
def get(self, resource_id: UUID):
|
||||
"""Get all API keys for a dataset"""
|
||||
return dump_response(ApiKeyList, self._get_api_key_list(str(resource_id), current_tenant_id))
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
@with_current_tenant_id
|
||||
@edit_permission_required
|
||||
def post(self, current_tenant_id: str, resource_id: UUID) -> tuple[dict[str, object], int]:
|
||||
def post(self, resource_id: UUID):
|
||||
"""Create a new API key for a dataset"""
|
||||
return dump_response(ApiKeyItem, self._create_api_key(str(resource_id), current_tenant_id)), 201
|
||||
return super().post(resource_id)
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
@ -250,14 +221,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, resource_id: UUID, api_key_id: UUID
|
||||
) -> tuple[str, int]:
|
||||
def delete(self, resource_id: UUID, api_key_id: UUID):
|
||||
"""Delete an API key for a dataset"""
|
||||
self._delete_api_key(str(resource_id), str(api_key_id), current_tenant_id, current_user)
|
||||
return "", 204
|
||||
return super().delete(str(resource_id), str(api_key_id))
|
||||
|
||||
resource_type = ApiTokenType.DATASET
|
||||
resource_model = Dataset
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ class AgentLogApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
args = AgentLogQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from controllers.common.fields import RedirectUrlResponse, SimpleResultResponse
|
||||
from controllers.common.helpers import FileInfo
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model, with_session
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.workspace.models import LoadBalancingPayload
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -26,6 +26,7 @@ from controllers.console.wraps import (
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@ -573,7 +574,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model(mode=None)
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
|
||||
@ -581,7 +582,7 @@ class AppApi(Resource):
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id))
|
||||
app_model.access_mode = app_setting.access_mode # type: ignore[attr-defined]
|
||||
app_model.access_mode = app_setting.access_mode
|
||||
|
||||
response_model = AppDetailWithSite.model_validate(app_model, from_attributes=True)
|
||||
return response_model.model_dump(mode="json")
|
||||
@ -598,7 +599,7 @@ class AppApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App):
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
|
||||
@ -627,7 +628,7 @@ class AppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App):
|
||||
def delete(self, app_model):
|
||||
"""Delete app"""
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
@ -648,7 +649,7 @@ class AppCopyApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -709,7 +710,7 @@ class AppExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -731,7 +732,7 @@ class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
@ -762,7 +763,7 @@ class AppNameApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -784,7 +785,7 @@ class AppIconApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_service = AppService()
|
||||
@ -811,7 +812,7 @@ class AppSiteStatus(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -833,7 +834,7 @@ class AppApiStatus(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
|
||||
app_service = AppService()
|
||||
@ -851,11 +852,11 @@ class AppTraceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_session
|
||||
@get_app_model
|
||||
def get(self, session: Session, app_model: App):
|
||||
def get(self, app_model):
|
||||
"""Get app trace"""
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
with session_factory.create_session() as session:
|
||||
app_trace_config = OpsTraceManager.get_app_tracing_config(app_model.id, session)
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@ -874,7 +875,7 @@ class AppTraceApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ class ChatMessageAudioApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
@ -171,7 +171,7 @@ class TextModesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
try:
|
||||
args = TextToSpeechVoiceQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -84,7 +84,7 @@ class CompletionMessageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App, task_id: str):
|
||||
def post(self, app_model, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@ -159,7 +159,7 @@ class ChatMessageApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
|
||||
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
def post(self, app_model, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ from fields.conversation_fields import (
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
@ -93,7 +93,7 @@ class CompletionConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -134,7 +134,7 @@ class CompletionConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
.distinct()
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
query = (
|
||||
@ -165,7 +165,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationMessageDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
@ -182,7 +182,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
@ -207,7 +207,7 @@ class ChatConversationApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -272,7 +272,7 @@ class ChatConversationApi(Resource):
|
||||
.join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
.group_by(Conversation.id)
|
||||
.distinct()
|
||||
)
|
||||
case "not_annotated":
|
||||
query = (
|
||||
@ -318,7 +318,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, conversation_id: UUID):
|
||||
def get(self, app_model, conversation_id: UUID):
|
||||
conversation_id_str = str(conversation_id)
|
||||
return ConversationDetailResponse.model_validate(
|
||||
_get_conversation(app_model, conversation_id_str), from_attributes=True
|
||||
@ -335,7 +335,7 @@ class ChatConversationDetailApi(Resource):
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, conversation_id: UUID):
|
||||
def delete(self, app_model, conversation_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation_id_str = str(conversation_id)
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
@ -94,7 +94,7 @@ class ConversationVariablesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
stmt = (
|
||||
|
||||
@ -11,7 +11,7 @@ from controllers.console.app.error import (
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -22,7 +22,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
@ -64,9 +64,9 @@ class RuleGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
@ -93,9 +93,9 @@ class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
@ -125,9 +125,9 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
@ -157,9 +157,9 @@ class InstructionGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
|
||||
@ -11,18 +11,13 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import AppMCPServerStatus
|
||||
from models.model import App, AppMCPServer
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
@ -78,7 +73,7 @@ class AppMCPServerController(Resource):
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
|
||||
if server is None:
|
||||
return {}
|
||||
@ -97,8 +92,8 @@ class AppMCPServerController(Resource):
|
||||
@login_required
|
||||
@setup_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, app_model: App):
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = MCPServerCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
description = payload.description
|
||||
@ -132,7 +127,7 @@ class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App):
|
||||
def put(self, app_model):
|
||||
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
|
||||
server = db.session.get(AppMCPServer, payload.id)
|
||||
if not server:
|
||||
@ -168,8 +163,8 @@ class AppMCPServerRefreshController(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, server_id: UUID):
|
||||
def get(self, server_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
server = db.session.scalar(
|
||||
select(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id, AppMCPServer.tenant_id == current_tenant_id)
|
||||
|
||||
@ -45,7 +45,7 @@ from libs.helper import to_timestamp, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
@ -180,7 +180,7 @@ class ChatMessageListApi(Resource):
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict())
|
||||
|
||||
conversation = db.session.scalar(
|
||||
@ -257,7 +257,7 @@ class MessageFeedbackApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
@ -314,7 +314,7 @@ class MessageAnnotationCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
count = db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id)
|
||||
)
|
||||
@ -337,7 +337,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
def get(self, app_model, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -379,7 +379,7 @@ class MessageFeedbackExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Import the service function
|
||||
@ -417,7 +417,7 @@ class MessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App, message_id: UUID):
|
||||
def get(self, app_model, message_id: str):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
|
||||
@ -16,7 +16,7 @@ from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class ModelConfigResource(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# validate config
|
||||
|
||||
@ -20,7 +20,6 @@ from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
from models.model import App
|
||||
|
||||
|
||||
class AppSiteUpdatePayload(BaseModel):
|
||||
@ -85,7 +84,7 @@ class AppSite(Resource):
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
@ -134,7 +133,7 @@ class AppSiteAccessTokenReset(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@ from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
from models.model import App
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
@ -48,7 +47,7 @@ class DailyMessageStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -62,12 +61,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -109,7 +104,7 @@ class DailyConversationStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -123,12 +118,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -169,7 +160,7 @@ class DailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -183,12 +174,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -230,7 +217,7 @@ class DailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -245,12 +232,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -294,7 +277,7 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -316,12 +299,8 @@ FROM
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -374,7 +353,7 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -392,12 +371,8 @@ LEFT JOIN
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -444,7 +419,7 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -458,12 +433,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
@ -505,7 +476,7 @@ class TokensPerSecondStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -521,12 +492,8 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict: dict[str, object] = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"invoke_from": InvokeFrom.DEBUGGER,
|
||||
}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -83,14 +82,13 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
match value:
|
||||
case FileSegment():
|
||||
file = value.value
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
case ArrayFileSegment():
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
@ -347,15 +345,14 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -366,7 +363,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -393,11 +390,10 @@ class VariableApi(Resource):
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -438,15 +434,14 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -462,7 +457,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -473,11 +468,10 @@ class VariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id_str,
|
||||
variable_id=variable_id,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
|
||||
@ -1,415 +0,0 @@
|
||||
"""Console REST endpoints for the Node Output Inspector (Stage 4 §8 / §10.3).
|
||||
|
||||
PRD §Node Output Inspector replaces the consumer-organized Variable Inspector
|
||||
with a producer-organized view of each node's declared outputs and their
|
||||
per-run status. This module exposes two parallel sets of three read-only
|
||||
endpoints — one for ``/workflows/draft/runs/...`` (Composer test runs) and one
|
||||
for ``/workflows/published/runs/...`` (real App API / webapp / webhook /
|
||||
schedule / plugin triggers). Both sets share the same service code, the same
|
||||
response shapes, and the same error codes; the URL is the *only* difference,
|
||||
so the frontend can pick the right prefix based on which run-detail page the
|
||||
user is on.
|
||||
|
||||
Decision D-1 (published Inspector deferred) was lifted 2026-05-26 — the
|
||||
``published_run_inspector_not_implemented`` 404 code is therefore no longer
|
||||
produced.
|
||||
|
||||
URLs follow the design doc and reuse the existing
|
||||
``/apps/<uuid:app_id>/workflows/draft/...`` prefix from
|
||||
:mod:`controllers.console.app.workflow_draft_variable`. The
|
||||
``published`` prefix mirrors it shape-for-shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from services.workflow import inspector_events
|
||||
from services.workflow.node_output_inspector_service import (
|
||||
NodeOutputInspectorError,
|
||||
NodeOutputInspectorService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Heartbeat cadence — every N empty subscribe ticks emit a SSE comment so
|
||||
# intervening proxies (nginx, ingress) don't reap the idle connection.
|
||||
# ``inspector_events.subscribe`` ticks at 1s, so 15 → 15s heartbeat.
|
||||
_HEARTBEAT_EVERY_TICKS = 15
|
||||
# Hard ceiling on a single stream — if we never see a terminal workflow
|
||||
# event (engine crashed, redis dropped the message), force-close after this
|
||||
# many ticks (= seconds).
|
||||
_STREAM_HARD_TIMEOUT_TICKS = 1800 # 30 min
|
||||
|
||||
|
||||
def _service() -> NodeOutputInspectorService:
|
||||
"""One-line factory so tests can monkeypatch a stub if needed."""
|
||||
return NodeOutputInspectorService()
|
||||
|
||||
|
||||
def _serve_snapshot(app_model: App, run_id: UUID) -> dict:
|
||||
"""Resource-body shared by draft + published snapshot endpoints.
|
||||
|
||||
Pulled out so the 6 REST routes don't duplicate the same 6-line try/except
|
||||
+ ``model_dump`` ritual — the routes shrink to one-liners and the actual
|
||||
behaviour lives here, where unit tests can hit it without spinning up
|
||||
Flask request context.
|
||||
"""
|
||||
try:
|
||||
snapshot = _service().snapshot_workflow_run(app_model=app_model, workflow_run_id=str(run_id))
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return snapshot.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_node_detail(app_model: App, run_id: UUID, node_id: str) -> dict:
|
||||
"""Resource-body shared by draft + published node-detail endpoints."""
|
||||
try:
|
||||
view = _service().node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return view.model_dump(mode="json")
|
||||
|
||||
|
||||
def _serve_output_preview(app_model: App, run_id: UUID, node_id: str, output_name: str) -> dict:
|
||||
"""Resource-body shared by draft + published output-preview endpoints."""
|
||||
try:
|
||||
preview = _service().output_preview(
|
||||
app_model=app_model,
|
||||
workflow_run_id=str(run_id),
|
||||
node_id=node_id,
|
||||
output_name=output_name,
|
||||
)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
return preview.model_dump(mode="json")
|
||||
|
||||
|
||||
class _InspectorNotFound(BaseHTTPException):
|
||||
"""404 that preserves the inspector's specific error code.
|
||||
|
||||
Without this the response body collapses to a generic ``not_found`` code
|
||||
and clients lose the ability to distinguish, e.g.,
|
||||
``workflow_run_not_found`` from ``published_run_inspector_not_implemented``.
|
||||
"""
|
||||
|
||||
code = 404
|
||||
|
||||
def __init__(self, error: NodeOutputInspectorError) -> None:
|
||||
self.error_code = error.code
|
||||
super().__init__(description=str(error))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowDraftRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot organized by producer node."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowDraftRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a draft workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowDraftRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output (with signed URL for file refs)."""
|
||||
|
||||
@console_ns.doc("get_workflow_draft_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output, including signed download URL for files.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# SSE event stream — shared generator used by draft + published variants
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _sse_envelope(event: str, data: dict | str, event_id: int) -> str:
|
||||
"""Format one SSE record per D-5 ``{event, data, id}`` envelope.
|
||||
|
||||
``data`` is JSON-serialized when given as a dict; raw strings are
|
||||
forwarded unchanged so we can also emit ``:keepalive`` comment lines.
|
||||
"""
|
||||
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event}\nid: {event_id}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
def _stream_inspector_events(app_model: App, run_id: UUID) -> Iterator[str]:
|
||||
"""Yield SSE-framed strings for one workflow run.
|
||||
|
||||
The stream begins with a full ``snapshot`` event so the client has a
|
||||
starting state without needing a separate REST GET. Then for every
|
||||
``node_changed`` message from the pub/sub channel we re-read that node
|
||||
from DB and push a fresh ``node_changed`` event. When the workflow run
|
||||
reaches a terminal state we push one final ``workflow_run_completed``
|
||||
event and close the stream.
|
||||
|
||||
Failures inside the loop are caught and surfaced as ``error`` events so
|
||||
the frontend can show a banner rather than seeing the connection drop
|
||||
silently. The Inspector never raises across the SSE boundary.
|
||||
"""
|
||||
service = _service()
|
||||
run_id_str = str(run_id)
|
||||
|
||||
# Initial snapshot — also flushes a 404 back at the client right away
|
||||
# if the run is gone (raised before yielding any bytes, so Flask turns it
|
||||
# into the normal HTTP 404 path).
|
||||
try:
|
||||
snapshot = service.snapshot_workflow_run(app_model=app_model, workflow_run_id=run_id_str)
|
||||
except NodeOutputInspectorError as error:
|
||||
raise _InspectorNotFound(error) from error
|
||||
|
||||
event_id = 0
|
||||
yield _sse_envelope("snapshot", snapshot.model_dump(mode="json"), event_id)
|
||||
|
||||
# If the run already finished by the time the client connected, emit
|
||||
# the terminal envelope synchronously and close — no point subscribing.
|
||||
# The enum value for partial success is the hyphenated ``partial-succeeded``
|
||||
# (graphon.enums.WorkflowExecutionStatus), not ``partial_succeeded``.
|
||||
if snapshot.workflow_run_status.value in {"succeeded", "failed", "stopped", "partial-succeeded"}:
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": snapshot.workflow_run_status.value},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Live subscription
|
||||
ticks_since_heartbeat = 0
|
||||
total_ticks = 0
|
||||
for message in inspector_events.subscribe(run_id_str, timeout_seconds=1.0):
|
||||
total_ticks += 1
|
||||
if total_ticks > _STREAM_HARD_TIMEOUT_TICKS:
|
||||
logger.warning(
|
||||
"Inspector SSE: forcing close after %ds without terminal event for run %s",
|
||||
_STREAM_HARD_TIMEOUT_TICKS,
|
||||
run_id_str,
|
||||
)
|
||||
return
|
||||
|
||||
# Heartbeat sentinel — ``inspector_events.subscribe`` synthesizes a
|
||||
# ``node_changed`` message with both fields ``None`` on every redis
|
||||
# timeout. Real ``workflow_completed`` messages keep their kind even
|
||||
# when status couldn't be resolved (publisher race), so checking kind
|
||||
# first makes the heartbeat branch safe.
|
||||
if message.kind == "node_changed" and message.node_id is None and message.status is None:
|
||||
ticks_since_heartbeat += 1
|
||||
if ticks_since_heartbeat >= _HEARTBEAT_EVERY_TICKS:
|
||||
yield ":keepalive\n\n"
|
||||
ticks_since_heartbeat = 0
|
||||
continue
|
||||
ticks_since_heartbeat = 0
|
||||
|
||||
if message.kind == "workflow_completed":
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"workflow_run_completed",
|
||||
{"workflow_run_id": run_id_str, "workflow_run_status": message.status or "unknown"},
|
||||
event_id,
|
||||
)
|
||||
return
|
||||
|
||||
# node_changed: recompute the node slice from DB
|
||||
if not message.node_id:
|
||||
continue
|
||||
try:
|
||||
node_view = service.node_detail(
|
||||
app_model=app_model,
|
||||
workflow_run_id=run_id_str,
|
||||
node_id=message.node_id,
|
||||
)
|
||||
except NodeOutputInspectorError:
|
||||
# Node may not appear in the graph yet (race with persistence); skip.
|
||||
continue
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Inspector SSE: node_detail failed for run %s node %s",
|
||||
run_id_str,
|
||||
message.node_id,
|
||||
exc_info=True,
|
||||
)
|
||||
event_id += 1
|
||||
yield _sse_envelope(
|
||||
"error",
|
||||
{"node_id": message.node_id, "message": "failed to refresh node detail"},
|
||||
event_id,
|
||||
)
|
||||
continue
|
||||
|
||||
event_id += 1
|
||||
yield _sse_envelope("node_changed", node_view.model_dump(mode="json"), event_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowDraftRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a draft run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_draft_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a draft workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Published-run endpoints — symmetric to the draft trio above
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs")
|
||||
class WorkflowPublishedRunNodeOutputsApi(Resource):
|
||||
"""Whole-run snapshot for a *published* workflow run.
|
||||
|
||||
Same response shape as the ``/draft/`` variant — frontend can multiplex
|
||||
based on which page (Composer test-run vs. Run History) is mounted.
|
||||
"""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_outputs")
|
||||
@console_ns.doc(description="Snapshot of every node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return _serve_snapshot(app_model, run_id)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/<string:node_id>")
|
||||
class WorkflowPublishedRunNodeOutputDetailApi(Resource):
|
||||
"""One node's declared outputs + per-output status (published run)."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_detail")
|
||||
@console_ns.doc(description="One node's declared outputs for a published workflow run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str):
|
||||
return _serve_node_detail(app_model, run_id, node_id)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>"
|
||||
"/node-outputs/<string:node_id>/<string:output_name>/preview"
|
||||
)
|
||||
class WorkflowPublishedRunNodeOutputPreviewApi(Resource):
|
||||
"""Full value for one declared output of a published run."""
|
||||
|
||||
@console_ns.doc("get_workflow_published_run_node_output_preview")
|
||||
@console_ns.doc(description="Full value for one declared output of a published run.")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"app_id": "Application ID",
|
||||
"run_id": "Workflow run ID",
|
||||
"node_id": "Node ID inside the workflow graph",
|
||||
"output_name": "Declared output name as exposed by Composer",
|
||||
}
|
||||
)
|
||||
@console_ns.response(404, "Workflow run / node / output not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID, node_id: str, output_name: str):
|
||||
return _serve_output_preview(app_model, run_id, node_id, output_name)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/published/runs/<uuid:run_id>/node-outputs/events")
|
||||
class WorkflowPublishedRunNodeOutputEventsApi(Resource):
|
||||
"""SSE stream of inspector deltas for a published run."""
|
||||
|
||||
@console_ns.doc("stream_workflow_published_run_node_output_events")
|
||||
@console_ns.doc(description="Server-Sent Events stream of inspector deltas for a published workflow run.")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
return Response(
|
||||
_stream_inspector_events(app_model, run_id),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
def get(self, app_model: App, run_id: str):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
run_id_str = str(run_id)
|
||||
|
||||
@ -11,7 +11,7 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, AppMode
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -86,7 +86,7 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -126,7 +126,7 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
@ -166,7 +166,7 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
@ -1,38 +1,16 @@
|
||||
"""Controller decorators for console app resources.
|
||||
|
||||
`with_session` opens one SQLAlchemy session for a request handler and injects it
|
||||
as the first argument after `self`. Handlers use a transaction by default so
|
||||
migrated write paths keep commit/rollback handling; pure read handlers may opt
|
||||
out with `write=False`. App-loading decorators prefer that injected session when
|
||||
present, while still supporting existing handlers that have not been migrated
|
||||
yet and still rely on Flask-SQLAlchemy's scoped `db.session`.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, cast, overload
|
||||
from typing import overload
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(session: Session, app_id: str) -> App | None:
|
||||
"""Load the tenant-scoped app row with the request session owned by `with_session`."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def _load_app_model_from_scoped_session(app_id: str) -> App | None:
|
||||
"""Load the app row for legacy handlers that have not adopted request session injection yet."""
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
app_model = db.session.scalar(
|
||||
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
|
||||
@ -45,63 +23,6 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R],
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_session[T, **P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]: ...
|
||||
|
||||
|
||||
def with_session[T, **P, R](
|
||||
view: Callable[Concatenate[T, Session, P], R] | None = None,
|
||||
*,
|
||||
write: bool = True,
|
||||
) -> (
|
||||
Callable[Concatenate[T, P], R] | Callable[[Callable[Concatenate[T, Session, P], R]], Callable[Concatenate[T, P], R]]
|
||||
):
|
||||
"""Inject a request-scoped session, using a transaction only for write handlers."""
|
||||
|
||||
def decorator(view: Callable[Concatenate[T, Session, P], R]) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if write:
|
||||
with session_factory.get_session_maker().begin() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
return view(self, session, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def _get_injected_session(args: tuple[object, ...]) -> Session | None:
|
||||
"""Return the request session inserted by `with_session`, if this handler has been migrated."""
|
||||
if len(args) < 2:
|
||||
return None
|
||||
|
||||
candidate = args[1]
|
||||
if isinstance(candidate, Session):
|
||||
return candidate
|
||||
|
||||
if hasattr(candidate, "scalar") and hasattr(candidate, "commit") and hasattr(candidate, "rollback"):
|
||||
return cast(Session, candidate)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
@ -123,13 +44,6 @@ def get_app_model[**P, R](
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Inject the App model for handlers that receive an `app_id` path parameter.
|
||||
|
||||
New handlers may compose `@with_session` above this decorator so the app row
|
||||
is loaded through the same request-scoped session used by the controller.
|
||||
Existing handlers continue to work through `db.session` until migrated.
|
||||
"""
|
||||
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@ -141,11 +55,7 @@ def get_app_model[**P, R](
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
session = _get_injected_session(args)
|
||||
if session is None:
|
||||
app_model = _load_app_model_from_scoped_session(app_id)
|
||||
else:
|
||||
app_model = _load_app_model(session, app_id)
|
||||
app_model = _load_app_model(app_id)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
|
||||
|
||||
@ -5,12 +5,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from .. import console_ns
|
||||
from ..auth.error import ApiKeyAuthFailedError
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required, with_current_tenant_id
|
||||
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
|
||||
|
||||
class ApiKeyAuthBindingPayload(BaseModel):
|
||||
@ -42,8 +42,8 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
@ -69,9 +69,9 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.expect(console_ns.models[ApiKeyAuthBindingPayload.__name__])
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
payload = ApiKeyAuthBindingPayload.model_validate(console_ns.payload)
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
@ -89,9 +89,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
@console_ns.response(204, "Binding deleted successfully")
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
def delete(self, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -8,9 +8,9 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required, with_current_user
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
@ -133,10 +133,12 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, current_user: Account):
|
||||
user_account_id = current_user.id
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
|
||||
@ -48,6 +48,7 @@ class NotionEstimatePayload(BaseModel):
|
||||
class DataSourceNotionListQuery(BaseModel):
|
||||
dataset_id: str | None = Field(default=None, description="Dataset ID")
|
||||
credential_id: str = Field(..., description="Credential ID", min_length=1)
|
||||
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
|
||||
|
||||
|
||||
class DataSourceNotionPreviewQuery(BaseModel):
|
||||
@ -204,6 +205,9 @@ class DataSourceNotionListApi(Resource):
|
||||
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
|
||||
datasource_parameters = query.datasource_parameters or {}
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -251,7 +255,7 @@ class DataSourceNotionListApi(Resource):
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -979,7 +979,7 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
@ -996,7 +996,7 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
def post(self, dataset_id: str):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -5,7 +5,7 @@ from uuid import UUID
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import String, case, cast, func, literal, or_, select
|
||||
from sqlalchemy import String, cast, func, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@ -169,17 +169,9 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
# Use database-specific methods for JSON array search
|
||||
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
|
||||
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
|
||||
# Feed the set-returning function a JSON array in every row. Filtering in
|
||||
# the subquery is not enough because PostgreSQL can still evaluate the
|
||||
# SRF on scalar JSON before applying the predicate.
|
||||
keywords_jsonb = cast(DocumentSegment.keywords, JSONB)
|
||||
keywords_array = case(
|
||||
(func.jsonb_typeof(keywords_jsonb) == "array", keywords_jsonb),
|
||||
else_=cast(literal("[]"), JSONB),
|
||||
)
|
||||
keywords_condition = func.array_to_string(
|
||||
func.array(
|
||||
select(func.jsonb_array_elements_text(keywords_array))
|
||||
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
|
||||
.correlate(DocumentSegment)
|
||||
.scalar_subquery()
|
||||
),
|
||||
|
||||
@ -10,12 +10,7 @@ from controllers.common.fields import UsageCountResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
@ -131,9 +126,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_tenant_id
|
||||
@account_initialization_required
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
from controllers.common.schema import register_schema_models
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -17,8 +20,86 @@ from ..wraps import (
|
||||
setup_required,
|
||||
)
|
||||
|
||||
register_schema_models(console_ns, HitTestingPayload)
|
||||
register_response_schema_models(console_ns, HitTestingResponse)
|
||||
|
||||
class HitTestingDocument(ResponseModel):
|
||||
id: str | None = None
|
||||
data_source_type: str | None = None
|
||||
name: str | None = None
|
||||
doc_type: str | None = None
|
||||
doc_metadata: Any | None = None
|
||||
|
||||
|
||||
class HitTestingSegment(ResponseModel):
|
||||
id: str | None = None
|
||||
position: int | None = None
|
||||
document_id: str | None = None
|
||||
content: str | None = None
|
||||
sign_content: str | None = None
|
||||
answer: str | None = None
|
||||
word_count: int | None = None
|
||||
tokens: int | None = None
|
||||
keywords: list[str] = Field(default_factory=list)
|
||||
index_node_id: str | None = None
|
||||
index_node_hash: str | None = None
|
||||
hit_count: int | None = None
|
||||
enabled: bool | None = None
|
||||
disabled_at: int | None = None
|
||||
disabled_by: str | None = None
|
||||
status: str | None = None
|
||||
created_by: str | None = None
|
||||
created_at: int | None = None
|
||||
indexing_at: int | None = None
|
||||
completed_at: int | None = None
|
||||
error: str | None = None
|
||||
stopped_at: int | None = None
|
||||
document: HitTestingDocument | None = None
|
||||
|
||||
@field_validator("disabled_at", "created_at", "indexing_at", "completed_at", "stopped_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class HitTestingChildChunk(ResponseModel):
|
||||
id: str | None = None
|
||||
content: str | None = None
|
||||
position: int | None = None
|
||||
score: float | None = None
|
||||
|
||||
|
||||
class HitTestingFile(ResponseModel):
|
||||
id: str | None = None
|
||||
name: str | None = None
|
||||
size: int | None = None
|
||||
extension: str | None = None
|
||||
mime_type: str | None = None
|
||||
source_url: str | None = None
|
||||
|
||||
|
||||
class HitTestingRecord(ResponseModel):
|
||||
segment: HitTestingSegment | None = None
|
||||
child_chunks: list[HitTestingChildChunk] = Field(default_factory=list)
|
||||
score: float | None = None
|
||||
tsne_position: Any | None = None
|
||||
files: list[HitTestingFile] = Field(default_factory=list)
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class HitTestingResponse(ResponseModel):
|
||||
query: str
|
||||
records: list[HitTestingRecord] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
HitTestingPayload,
|
||||
HitTestingDocument,
|
||||
HitTestingSegment,
|
||||
HitTestingChildChunk,
|
||||
HitTestingFile,
|
||||
HitTestingRecord,
|
||||
HitTestingResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@ -38,11 +119,12 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID) -> dict[str, object]:
|
||||
def post(self, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args(console_ns.payload)
|
||||
payload = HitTestingPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
return HitTestingResponse.model_validate(self.perform_hit_testing(dataset, args)).model_dump(mode="json")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -18,10 +19,10 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -37,6 +38,16 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _extract_hit_testing_query(query: Any) -> str:
|
||||
"""Return the query string from the service response shape."""
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Ensure collection fields match the API schema before response validation."""
|
||||
@ -52,7 +63,6 @@ class DatasetsHitTestingBase:
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
normalized_segment.setdefault("sign_content", None)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
@ -63,15 +73,12 @@ class DatasetsHitTestingBase:
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_record.setdefault("tsne_position", None)
|
||||
normalized_record.setdefault("summary", None)
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str) -> Dataset:
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -85,35 +92,33 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args: dict[str, Any]) -> None:
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args(payload: dict[str, Any] | None) -> dict[str, Any]:
|
||||
def parse_args(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate and return hit-testing arguments from an incoming payload."""
|
||||
hit_testing_payload = HitTestingPayload.model_validate(payload or {})
|
||||
return hit_testing_payload.model_dump(exclude_none=True)
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]:
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=cast(str, args.get("query")),
|
||||
query=args.get("query"),
|
||||
account=current_user,
|
||||
retrieval_model=args.get("retrieval_model"),
|
||||
external_retrieval_model=cast(dict[str, Any], args.get("external_retrieval_model")),
|
||||
external_retrieval_model=args.get("external_retrieval_model"),
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
query = response.get("query")
|
||||
if not isinstance(query, dict) or not isinstance(query.get("content"), str):
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
return {
|
||||
"query": {"content": query["content"]},
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(response.get("records", [])),
|
||||
"query": DatasetsHitTestingBase._extract_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._prepare_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
@ -169,22 +168,21 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -212,12 +210,11 @@ class RagPipelineVariableApi(Resource):
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
@ -253,16 +250,15 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
@ -271,7 +267,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -282,12 +278,11 @@ class RagPipelineVariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
|
||||
@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
def get(self, pipeline: Pipeline, run_id: str):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import InstalledApp
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -41,10 +40,8 @@ register_schema_model(console_ns, TextToAudioPayload)
|
||||
endpoint="installed_app_audio",
|
||||
)
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
file = request.files["file"]
|
||||
|
||||
@ -84,10 +81,8 @@ class ChatAudioApi(InstalledAppResource):
|
||||
)
|
||||
class ChatTextApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToAudioPayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -83,10 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
)
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -135,10 +133,8 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
def post(self, installed_app, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -161,10 +157,8 @@ class CompletionStopApi(InstalledAppResource):
|
||||
)
|
||||
class ChatApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -215,10 +209,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
def post(self, installed_app, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -8,7 +8,6 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,7 +20,7 @@ from fields.conversation_fields import (
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
|
||||
from services.web_conversation_service import WebConversationService
|
||||
@ -45,10 +44,8 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
)
|
||||
class ConversationListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
|
||||
def get(self, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -95,10 +92,8 @@ class ConversationListApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
def delete(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def delete(self, installed_app, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -120,10 +115,8 @@ class ConversationApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationRenameApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
|
||||
def post(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def post(self, installed_app, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -153,10 +146,8 @@ class ConversationRenameApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
@ -179,10 +170,8 @@ class ConversationPinApi(InstalledAppResource):
|
||||
)
|
||||
class ConversationUnPinApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp, c_id: UUID):
|
||||
def patch(self, installed_app, c_id: UUID):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -262,7 +262,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
"""
|
||||
|
||||
@console_ns.response(204, "App uninstalled successfully")
|
||||
def delete(self, installed_app: InstalledApp):
|
||||
def delete(self, installed_app):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if installed_app.app_owner_tenant_id == current_tenant_id:
|
||||
raise BadRequest("You can't uninstall an app owned by the current tenant")
|
||||
@ -273,7 +273,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app: InstalledApp):
|
||||
def patch(self, installed_app):
|
||||
payload = InstalledAppUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
commit_args = False
|
||||
|
||||
@ -10,7 +10,6 @@ from controllers.common.controller_schemas import MessageFeedbackPayload, Messag
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -22,16 +21,15 @@ from controllers.console.explore.error import (
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode, InstalledApp
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@ -61,11 +59,9 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe
|
||||
)
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
@ -100,11 +96,9 @@ class MessageListApi(InstalledAppResource):
|
||||
class MessageFeedbackApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def post(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -130,11 +124,9 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
)
|
||||
class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -178,11 +170,9 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
)
|
||||
class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def get(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
@ -7,14 +7,11 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from models import Account
|
||||
from models.model import InstalledApp
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -25,11 +22,9 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app: InstalledApp):
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -51,11 +46,9 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
@ -74,11 +67,9 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app: InstalledApp, message_id: UUID):
|
||||
def delete(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -26,7 +25,7 @@ from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -42,11 +41,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
def post(self, installed_app: InstalledApp):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -9,14 +9,14 @@ from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from constants import HIDDEN_VALUE
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.api_based_extension_service import APIBasedExtensionService
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
from ..common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_models
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, setup_required, with_current_tenant_id
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
|
||||
class CodeBasedExtensionQuery(BaseModel):
|
||||
@ -116,11 +116,11 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@ -130,9 +130,9 @@ class APIBasedExtensionAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
def post(self):
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data = APIBasedExtension(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -153,12 +153,12 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, id: UUID):
|
||||
def get(self, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@ -169,9 +169,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, id: UUID):
|
||||
def post(self, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
@ -197,9 +197,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, id: UUID):
|
||||
def delete(self, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
|
||||
|
||||
@ -2,11 +2,11 @@ from flask_restx import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_user, login_required
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
@ -24,9 +24,10 @@ class FeatureApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
@ -48,9 +49,10 @@ class FeatureVectorSpaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
|
||||
@ -22,13 +22,10 @@ from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
@ -65,8 +62,8 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileResponse.__name__])
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
source_str = request.form.get("source")
|
||||
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
|
||||
|
||||
@ -110,10 +107,10 @@ class FilePreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, file_id: UUID):
|
||||
def get(self, file_id: UUID):
|
||||
file_id_str = str(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, current_tenant_id)
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -8,14 +8,8 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
@ -76,10 +70,11 @@ class NotificationApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self, current_user: Account):
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
@ -118,11 +113,11 @@ class NotificationDismissApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@ -12,13 +12,11 @@ from controllers.common.errors import (
|
||||
)
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import with_current_user
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -51,8 +49,7 @@ class RemoteFileUpload(Resource):
|
||||
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
|
||||
@console_ns.response(201, "File uploaded successfully", console_ns.models[FileWithSignedUrl.__name__])
|
||||
@login_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
def post(self):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = payload.url
|
||||
|
||||
@ -77,11 +74,12 @@ class RemoteFileUpload(Resource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=current_user,
|
||||
user=user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
|
||||
@ -9,16 +9,9 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
@ -99,8 +92,8 @@ class TagListApi(Resource):
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
@ -116,9 +109,9 @@ class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Allow users with edit permission, or dataset editors (including dataset operators).
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -139,8 +132,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, tag_id: UUID):
|
||||
def patch(self, tag_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -170,19 +163,20 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission(current_user: Account) -> None:
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
@ -195,8 +189,8 @@ def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
@ -219,9 +213,8 @@ class TagBindingCollectionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _create_tag_bindings(current_user)
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
@ -235,6 +228,5 @@ class TagBindingRemoveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _remove_tag_bindings(current_user)
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
|
||||
@ -4,7 +4,6 @@ from uuid import UUID
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy import func, select
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@ -23,15 +22,15 @@ from controllers.console.auth.error import (
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
is_allow_transfer_owner,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.member_fields import AccountWithRole, AccountWithRoleList
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
from services.feature_service import FeatureService
|
||||
@ -77,55 +76,7 @@ register_response_schema_models(console_ns, SimpleResultDataResponse, Verificati
|
||||
def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
|
||||
if role != TenantAccountRole.DATASET_OPERATOR:
|
||||
return True
|
||||
return FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True).dataset_operator_enabled
|
||||
|
||||
|
||||
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
|
||||
return list(dict.fromkeys(email.lower() for email in emails))
|
||||
|
||||
|
||||
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
|
||||
new_member_count = 0
|
||||
for email in emails:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email)
|
||||
if not account:
|
||||
new_member_count += 1
|
||||
continue
|
||||
|
||||
exists = db.session.scalar(
|
||||
select(TenantAccountJoin.id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
|
||||
.limit(1)
|
||||
)
|
||||
if not exists:
|
||||
new_member_count += 1
|
||||
|
||||
return new_member_count
|
||||
|
||||
|
||||
def _count_current_members(tenant_id: str) -> int:
|
||||
return (
|
||||
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
|
||||
)
|
||||
|
||||
|
||||
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
|
||||
if new_member_count <= 0:
|
||||
return
|
||||
|
||||
features = FeatureService.get_features(tenant_id=tenant_id, exclude_vector_space=True)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
workspace_members = features.workspace_members
|
||||
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
return
|
||||
|
||||
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
|
||||
members = features.members
|
||||
current_member_count = _count_current_members(tenant_id)
|
||||
if 0 < members.limit < current_member_count + new_member_count:
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/members")
|
||||
@ -154,11 +105,12 @@ class MemberInviteEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
payload = console_ns.payload or {}
|
||||
args = MemberInvitePayload.model_validate(payload)
|
||||
|
||||
invitee_emails = _normalize_invitee_emails(args.emails)
|
||||
invitee_emails = args.emails
|
||||
invitee_role = args.role
|
||||
interface_language = args.language
|
||||
if not TenantAccountRole.is_non_owner_role(invitee_role):
|
||||
@ -178,36 +130,37 @@ class MemberInviteEmailApi(Resource):
|
||||
invitation_results = []
|
||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||
|
||||
tenant_id = inviter.current_tenant.id
|
||||
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
|
||||
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
|
||||
_check_member_invite_limits(tenant_id, new_member_count)
|
||||
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
|
||||
if not workspace_members.is_available(len(invitee_emails)):
|
||||
raise WorkspaceMembersLimitExceeded()
|
||||
|
||||
for invitee_email in invitee_emails:
|
||||
normalized_invitee_email = invitee_email.lower()
|
||||
try:
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=inviter.current_tenant,
|
||||
email=invitee_email,
|
||||
language=interface_language,
|
||||
role=invitee_role,
|
||||
inviter=inviter,
|
||||
)
|
||||
encoded_invitee_email = parse.quote(normalized_invitee_email)
|
||||
invitation_results.append(
|
||||
{
|
||||
"status": "success",
|
||||
"email": normalized_invitee_email,
|
||||
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
|
||||
}
|
||||
)
|
||||
except AccountAlreadyInTenantError:
|
||||
invitation_results.append(
|
||||
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
|
||||
)
|
||||
except Exception as e:
|
||||
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
||||
@ -8,17 +8,12 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
@ -143,8 +138,9 @@ class DefaultModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -160,8 +156,9 @@ class DefaultModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str):
|
||||
def post(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserPostDefault.model_validate(console_ns.payload)
|
||||
model_provider_service = ModelProviderService()
|
||||
model_settings = args.model_settings
|
||||
@ -192,8 +189,9 @@ class ModelProviderModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider):
|
||||
def get(self, provider):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
@ -204,9 +202,9 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
def post(self, provider: str):
|
||||
# To save the model's load balance configs
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserPostModels.model_validate(console_ns.payload)
|
||||
|
||||
if args.config_from == "custom-model":
|
||||
@ -251,8 +249,9 @@ class ModelProviderModelApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, provider: str):
|
||||
def delete(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -269,8 +268,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
def get(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -323,8 +323,9 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserCreateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -354,8 +355,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, provider: str):
|
||||
def put(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = ParserUpdateCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -381,8 +382,8 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider: str):
|
||||
def delete(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = ParserDeleteCredential.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -405,8 +406,8 @@ class ModelProviderModelCredentialSwitchApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider: str):
|
||||
def post(self, provider: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = ParserSwitch.model_validate(console_ns.payload)
|
||||
|
||||
service = ModelProviderService()
|
||||
@ -429,8 +430,9 @@ class ModelProviderModelEnableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -450,8 +452,9 @@ class ModelProviderModelDisableApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, provider: str):
|
||||
def patch(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = ParserDeleteModels.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -477,8 +480,8 @@ class ModelProviderModelValidateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, provider: str):
|
||||
def post(self, provider: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
args = ParserValidate.model_validate(console_ns.payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@ -512,9 +515,9 @@ class ModelProviderModelParameterRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, provider: str):
|
||||
def get(self, provider: str):
|
||||
args = ParserParameter.model_validate(request.args.to_dict(flat=True))
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
parameter_rules = model_provider_service.get_model_parameter_rules(
|
||||
@ -529,8 +532,8 @@ class ModelProviderAvailableModelApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, model_type: str):
|
||||
def get(self, model_type: str):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
model_provider_service = ModelProviderService()
|
||||
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
|
||||
@ -166,10 +166,10 @@ class TenantListApi(Resource):
|
||||
if tenant_plan:
|
||||
plan = tenant_plan["plan"] or CloudPlan.SANDBOX
|
||||
else:
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
|
||||
@ -4,7 +4,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@ -17,7 +16,6 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.encryption import FieldEncryption
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
@ -84,7 +82,9 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -96,28 +96,21 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(current_tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403,
|
||||
"The capacity of the knowledge storage space has reached the limit of your subscription.",
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(
|
||||
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
|
||||
)
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places,
|
||||
# so we need to check the source of the request from datasets
|
||||
@ -147,7 +140,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
@ -205,11 +198,15 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
if dify_config.BILLING_ENABLED and utm_info:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -298,7 +295,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.knowledge_pipeline.publish_enabled:
|
||||
return view(*args, **kwargs)
|
||||
abort(403)
|
||||
@ -312,6 +309,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object() # type: ignore
|
||||
if not isinstance(user, Account):
|
||||
@ -329,6 +327,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
|
||||
user = current_user._get_current_object()
|
||||
if not isinstance(user, Account) or not user.is_admin_or_owner:
|
||||
@ -496,25 +495,3 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_tenant_id[T, **P, R](
|
||||
view: Callable[Concatenate[T, str, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
return view(self, current_tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -1,142 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.device_flow_security import attach_anti_framing
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
|
||||
attach_anti_framing(bp)
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="OpenAPI",
|
||||
description="User-scoped programmatic API (bearer auth)",
|
||||
)
|
||||
|
||||
openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")
|
||||
|
||||
# Register response/query models BEFORE importing controller modules so that
|
||||
# @openapi_ns.response / @openapi_ns.expect decorators can resolve model names.
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppInfoResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
MemberListQuery,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
MemberRoleUpdatePayload,
|
||||
MessageMetadata,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
ServerVersionResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
WorkflowRunData,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspacePayload,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from fields.file_fields import FileResponse
|
||||
|
||||
register_schema_models(
|
||||
openapi_ns,
|
||||
AppDescribeQuery,
|
||||
AppListQuery,
|
||||
AppRunRequest,
|
||||
DeviceCodeRequest,
|
||||
DevicePollRequest,
|
||||
DeviceLookupQuery,
|
||||
DeviceMutateRequest,
|
||||
MemberInvitePayload,
|
||||
MemberListQuery,
|
||||
MemberRoleUpdatePayload,
|
||||
PermittedExternalAppsListQuery,
|
||||
)
|
||||
register_response_schema_models(
|
||||
openapi_ns,
|
||||
TagItem,
|
||||
UsageInfo,
|
||||
MessageMetadata,
|
||||
AppListRow,
|
||||
AppListResponse,
|
||||
AppInfoResponse,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
WorkflowRunData,
|
||||
AccountPayload,
|
||||
WorkspacePayload,
|
||||
AccountResponse,
|
||||
SessionRow,
|
||||
SessionListResponse,
|
||||
PermittedExternalAppsListResponse,
|
||||
RevokeResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceDetailResponse,
|
||||
MemberResponse,
|
||||
MemberListResponse,
|
||||
MemberInviteResponse,
|
||||
MemberActionResponse,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateResponse,
|
||||
FileResponse,
|
||||
ServerVersionResponse,
|
||||
)
|
||||
|
||||
from . import (
|
||||
_meta,
|
||||
account,
|
||||
app_run,
|
||||
apps,
|
||||
apps_permitted_external,
|
||||
files,
|
||||
human_input_form,
|
||||
index,
|
||||
oauth_device,
|
||||
oauth_device_sso,
|
||||
workflow_events,
|
||||
workspaces,
|
||||
)
|
||||
|
||||
# Request models are imported from _models.py and registered above.
|
||||
|
||||
__all__ = [
|
||||
"_meta",
|
||||
"account",
|
||||
"app_run",
|
||||
"apps",
|
||||
"apps_permitted_external",
|
||||
"files",
|
||||
"human_input_form",
|
||||
"index",
|
||||
"oauth_device",
|
||||
"oauth_device_sso",
|
||||
"workflow_events",
|
||||
"workspaces",
|
||||
]
|
||||
|
||||
api.add_namespace(openapi_ns)
|
||||
@ -1,66 +0,0 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED = "openapi.wrong_surface_denied"
|
||||
|
||||
|
||||
def emit_app_run(
|
||||
*,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
caller_kind: str,
|
||||
mode: str,
|
||||
surface: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s surface=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
surface,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
"surface": surface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def emit_wrong_surface(
|
||||
*,
|
||||
subject_type: str | None,
|
||||
attempted_path: str,
|
||||
client_id: str | None,
|
||||
token_id: str | None,
|
||||
) -> None:
|
||||
logger.warning(
|
||||
"audit: %s subject_type=%s attempted_path=%s",
|
||||
EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
subject_type,
|
||||
attempted_path,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_OPENAPI_WRONG_SURFACE_DENIED,
|
||||
"subject_type": subject_type,
|
||||
"attempted_path": attempted_path,
|
||||
"client_id": client_id,
|
||||
"token_id": token_id,
|
||||
},
|
||||
)
|
||||
@ -1,143 +0,0 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out: dict[str, Any] = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
"""Resolve `(features_dict, user_input_form)` for parameters / schema derivation.
|
||||
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
return (
|
||||
workflow.features_dict,
|
||||
cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
|
||||
)
|
||||
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
|
||||
"""Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.
|
||||
|
||||
chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
|
||||
completion / workflow: `inputs` object only.
|
||||
Raises `AppUnavailableError` on misconfigured apps.
|
||||
"""
|
||||
_, user_input_form = resolve_app_config(app)
|
||||
inputs_props, inputs_required = _form_to_jsonschema(user_input_form)
|
||||
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
if app.mode in _CHAT_FAMILY:
|
||||
properties["query"] = {"type": "string", "minLength": 1}
|
||||
required.append("query")
|
||||
|
||||
properties["inputs"] = {
|
||||
"type": "object",
|
||||
"properties": inputs_props,
|
||||
"required": inputs_required,
|
||||
"additionalProperties": False,
|
||||
}
|
||||
required.append("inputs")
|
||||
|
||||
return {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
"""Meta endpoint: `GET /openapi/v1/_version` — no auth.
|
||||
|
||||
Returns the server's project version and edition so the difyctl CLI can probe
|
||||
compatibility without needing to be logged in. Mirrors the `_health` endpoint
|
||||
in `index.py`.
|
||||
"""
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import ServerVersionResponse
|
||||
|
||||
|
||||
@openapi_ns.route("/_version")
|
||||
class VersionApi(Resource):
|
||||
@openapi_ns.response(200, "Server version", openapi_ns.models[ServerVersionResponse.__name__])
|
||||
def get(self):
|
||||
edition = dify_config.EDITION if dify_config.EDITION in ("SELF_HOSTED", "CLOUD") else "SELF_HOSTED"
|
||||
return ServerVersionResponse(
|
||||
version=dify_config.project.version,
|
||||
edition=edition,
|
||||
).model_dump(mode="json")
|
||||
@ -1,402 +0,0 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from libs.helper import EmailStr, UUIDStrOrEmpty, uuid_value
|
||||
from models.model import AppMode
|
||||
|
||||
# Server-side cap on `limit` query param for /openapi/v1/* list endpoints.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
|
||||
usage: UsageInfo | None = None
|
||||
retriever_resources: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
|
||||
"""Canonical pagination envelope for `/openapi/v1/*` list endpoints."""
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[T]
|
||||
|
||||
@classmethod
|
||||
def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
|
||||
return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class TagItem(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: AppMode
|
||||
tags: list[TagItem] = []
|
||||
updated_at: str | None = None
|
||||
created_by_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
|
||||
|
||||
class AppListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class PermittedExternalAppsListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[AppListRow]
|
||||
|
||||
|
||||
class AppInfoResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
mode: str
|
||||
author: str | None = None
|
||||
tags: list[TagItem] = []
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
|
||||
updated_at: str | None = None
|
||||
service_api_enabled: bool
|
||||
is_agent: bool = False
|
||||
|
||||
|
||||
class AppDescribeResponse(BaseModel):
|
||||
info: AppDescribeInfo | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class CompletionMessageResponse(BaseModel):
|
||||
event: str
|
||||
task_id: str
|
||||
id: str
|
||||
message_id: str
|
||||
mode: str
|
||||
answer: str
|
||||
metadata: MessageMetadata = Field(default_factory=MessageMetadata)
|
||||
created_at: int
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
|
||||
id: str
|
||||
workflow_id: str
|
||||
status: str
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
elapsed_time: float | None = None
|
||||
total_tokens: int | None = None
|
||||
total_steps: int | None = None
|
||||
created_at: int | None = None
|
||||
finished_at: int | None = None
|
||||
|
||||
|
||||
class WorkflowRunResponse(BaseModel):
|
||||
workflow_run_id: str
|
||||
task_id: str
|
||||
mode: Literal["workflow"] = "workflow"
|
||||
data: WorkflowRunData
|
||||
|
||||
|
||||
class AccountPayload(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str
|
||||
|
||||
|
||||
class WorkspacePayload(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class AccountResponse(BaseModel):
|
||||
subject_type: str
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account: AccountPayload | None = None
|
||||
workspaces: list[WorkspacePayload] = []
|
||||
default_workspace_id: str | None = None
|
||||
|
||||
|
||||
class SessionRow(BaseModel):
|
||||
id: str
|
||||
prefix: str
|
||||
client_id: str
|
||||
device_label: str
|
||||
created_at: str | None = None
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[SessionRow]
|
||||
|
||||
|
||||
class RevokeResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class WorkspaceSummaryResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
workspaces: list[WorkspaceSummaryResponse]
|
||||
|
||||
|
||||
class WorkspaceDetailResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
status: str
|
||||
current: bool
|
||||
created_at: str | None = None
|
||||
|
||||
|
||||
class DeviceCodeResponse(BaseModel):
|
||||
device_code: str
|
||||
user_code: str
|
||||
verification_uri: str
|
||||
expires_in: int
|
||||
interval: int
|
||||
|
||||
|
||||
class DeviceLookupResponse(BaseModel):
|
||||
valid: bool
|
||||
expires_in_remaining: int = 0
|
||||
client_id: str | None = None
|
||||
|
||||
|
||||
class DeviceMutateResponse(BaseModel):
|
||||
status: str
|
||||
|
||||
|
||||
class ServerVersionResponse(BaseModel):
|
||||
"""Meta endpoint payload for `GET /openapi/v1/_version` — no auth required."""
|
||||
|
||||
version: str
|
||||
edition: Literal["SELF_HOSTED", "CLOUD"]
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
import uuid as _uuid
|
||||
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
_ALLOWED_DESCRIBE_FIELDS = frozenset({"info", "parameters", "input_schema"})
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
"""mode is a closed enum."""
|
||||
|
||||
workspace_id: str
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str | None = None
|
||||
files: list[dict[str, Any]] | None = None
|
||||
conversation_id: UUIDStrOrEmpty | None = None
|
||||
auto_generate_name: bool = True
|
||||
workflow_id: str | None = None
|
||||
workspace_id: UUIDStrOrEmpty | None = None
|
||||
|
||||
@field_validator("conversation_id", mode="before")
|
||||
@classmethod
|
||||
def _normalize_conv(cls, value: str | None) -> str | None:
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return uuid_value(value)
|
||||
except ValueError as exc:
|
||||
raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class PermittedExternalAppsListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
|
||||
|
||||
|
||||
class ExtSubjectAssertionClaims(BaseModel):
|
||||
email: str = _EMAIL_FIELD
|
||||
issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
class ApprovalGrantClaimsPayload(BaseModel):
|
||||
subject_email: str = _EMAIL_FIELD
|
||||
subject_issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
# Closed enum for invite/update-role payloads. Owner is intentionally not
|
||||
# assignable through these endpoints — ownership transfer goes through the
|
||||
# console's three-step email-verification flow.
|
||||
MemberAssignableRole = Literal["normal", "admin"]
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
avatar: str | None = None
|
||||
|
||||
|
||||
class MemberListResponse(BaseModel):
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
has_more: bool
|
||||
data: list[MemberResponse]
|
||||
|
||||
|
||||
class MemberListQuery(BaseModel):
|
||||
"""Strict (extra='forbid')."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
|
||||
|
||||
class MemberInvitePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
email: EmailStr
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberRoleUpdatePayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
role: MemberAssignableRole
|
||||
|
||||
|
||||
class MemberInviteResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
email: str
|
||||
role: str
|
||||
member_id: str
|
||||
invite_url: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class MemberActionResponse(BaseModel):
|
||||
result: Literal["success"] = "success"
|
||||
@ -1,144 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AccountPayload,
|
||||
AccountResponse,
|
||||
PaginationEnvelope,
|
||||
RevokeResponse,
|
||||
SessionListResponse,
|
||||
SessionRow,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
Scope,
|
||||
TokenType,
|
||||
get_auth_ctx,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
enforce,
|
||||
)
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.oauth_device_flow import (
|
||||
list_active_sessions,
|
||||
revoke_oauth_token,
|
||||
token_belongs_to_subject,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
|
||||
|
||||
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
|
||||
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
|
||||
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return AccountResponse(
|
||||
subject_type="account",
|
||||
subject_email=account.email if account else None,
|
||||
account=_account_payload(account) if account else None,
|
||||
workspaces=[_workspace_payload(m) for m in memberships],
|
||||
default_workspace_id=default_ws_id,
|
||||
).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, *, auth_data: AuthData):
|
||||
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = list_active_sessions(db.session, ctx, now)
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
SessionRow(
|
||||
id=str(r.id),
|
||||
prefix=r.prefix,
|
||||
client_id=r.client_id,
|
||||
device_label=r.device_label,
|
||||
created_at=_iso(r.created_at),
|
||||
last_used_at=_iso(r.last_used_at),
|
||||
expires_at=_iso(r.expires_at),
|
||||
)
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def delete(self, session_id: str, *, auth_data: AuthData):
|
||||
ctx = get_auth_ctx()
|
||||
|
||||
# 404 (not 403) on cross-subject so the endpoint doesn't leak
|
||||
# token IDs that belong to other subjects.
|
||||
if not token_belongs_to_subject(db.session, session_id, ctx):
|
||||
raise NotFound("session not found")
|
||||
|
||||
revoke_oauth_token(db.session, redis_client, session_id)
|
||||
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> WorkspacePayload:
|
||||
join, tenant = row
|
||||
return WorkspacePayload(id=str(tenant.id), name=tenant.name, role=getattr(join, "role", ""))
|
||||
|
||||
|
||||
def _account_payload(account) -> AccountPayload:
|
||||
return AccountPayload(id=str(account.id), email=account.email, name=account.name)
|
||||
@ -1,168 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import AppRunRequest
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _translate_service_errors() -> Iterator[None]:
|
||||
try:
|
||||
yield
|
||||
except WorkflowNotFoundError as ex:
|
||||
raise NotFound(str(ex))
|
||||
except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
|
||||
raise BadRequest(str(ex))
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeRateLimitError as ex:
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
|
||||
return AppGenerateService.generate(
|
||||
app_model=app,
|
||||
user=caller,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.OPENAPI,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest):
|
||||
if not payload.query or not payload.query.strip():
|
||||
raise UnprocessableEntity("query_required_for_chat")
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_completion(app: App, caller: Any, payload: AppRunRequest):
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
args["auto_generate_name"] = False
|
||||
args.setdefault("query", "")
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
def _run_workflow(app: App, caller: Any, payload: AppRunRequest):
|
||||
if payload.query is not None:
|
||||
raise UnprocessableEntity("query_not_supported_for_workflow")
|
||||
args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
|
||||
with _translate_service_errors():
|
||||
return _generate(app, caller, args, streaming=True)
|
||||
|
||||
|
||||
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
|
||||
AppMode.CHAT: _run_chat,
|
||||
AppMode.AGENT_CHAT: _run_chat,
|
||||
AppMode.ADVANCED_CHAT: _run_chat,
|
||||
AppMode.COMPLETION: _run_completion,
|
||||
AppMode.WORKFLOW: _run_workflow,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
|
||||
@openapi_ns.response(200, "Run result (SSE stream)")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
try:
|
||||
stream_obj = handler(app_model, caller, payload)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
surface="apps",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
|
||||
class AppRunTaskStopApi(Resource):
|
||||
@openapi_ns.response(200, "Task stopped")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
GraphEngineManager(redis_client).send_stop_command(task_id)
|
||||
return {"result": "success"}
|
||||
@ -1,243 +0,0 @@
|
||||
"""GET /openapi/v1/apps and per-app reads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
AppDescribeInfo,
|
||||
AppDescribeQuery,
|
||||
AppDescribeResponse,
|
||||
AppListQuery,
|
||||
AppListResponse,
|
||||
AppListRow,
|
||||
TagItem,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppListParams, AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
# ``str(parsed_uuid)`` normalises to the canonical dashed form.
|
||||
app = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None:
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = AppService.find_visible_apps_by_name(db.session, name=app_id, tenant_id=workspace_id)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
|
||||
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, app_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
is_agent=app.mode in ("agent-chat", "advanced-chat"),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
|
||||
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
workspace_id = query.workspace_id
|
||||
|
||||
empty = (
|
||||
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
if query.name:
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(query.name)
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = AppService.get_visible_app_by_id(db.session, str(parsed_uuid))
|
||||
if app is None or str(app.tenant_id) != workspace_id:
|
||||
return empty
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
item = AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[TagItem(name=t.name) for t in app.tags],
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=getattr(app, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
env = AppListResponse(page=1, limit=1, total=1, has_more=False, data=[item])
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
tag_ids: list[str] | None = None
|
||||
if query.tag:
|
||||
tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
|
||||
if not tags:
|
||||
return empty
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
|
||||
params = AppListParams(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else "all", # type:ignore
|
||||
name=query.name,
|
||||
tag_ids=tag_ids,
|
||||
status="normal",
|
||||
# Visibility gate pushed into the query — pagination.total stays
|
||||
# consistent across pages because invisible rows never count.
|
||||
openapi_visible=True,
|
||||
)
|
||||
|
||||
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name = None
|
||||
if pagination.items:
|
||||
tenant_name = TenantService.get_tenant_name(db.session, workspace_id)
|
||||
|
||||
items = [
|
||||
AppListRow(
|
||||
id=str(r.id),
|
||||
name=r.name,
|
||||
description=r.description,
|
||||
mode=r.mode,
|
||||
tags=[TagItem(name=t.name) for t in r.tags],
|
||||
updated_at=r.updated_at.isoformat() if r.updated_at else None,
|
||||
created_by_name=getattr(r, "author_name", None),
|
||||
workspace_id=str(workspace_id),
|
||||
workspace_name=tenant_name,
|
||||
)
|
||||
for r in pagination.items
|
||||
]
|
||||
|
||||
env = AppListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=cast(int, pagination.total),
|
||||
has_more=query.page * query.limit < cast(int, pagination.total),
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,92 +0,0 @@
|
||||
"""GET /openapi/v1/permitted-external-apps — external-subject app discovery (EE only).
|
||||
|
||||
`dfoe_` (External SSO) callers reach apps gated by ACL access-mode
|
||||
(public / sso_verified). License-gated: CE deploys never enable the
|
||||
EE blueprint chain so this module is unreachable there.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AppListRow,
|
||||
PermittedExternalAppsListQuery,
|
||||
PermittedExternalAppsListResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData, Edition
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
|
||||
@openapi_ns.route("/permitted-external-apps")
|
||||
class PermittedExternalAppsListApi(Resource):
|
||||
@openapi_ns.response(
|
||||
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
|
||||
)
|
||||
@auth_router.guard(
|
||||
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
|
||||
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
|
||||
edition=frozenset({Edition.EE}),
|
||||
)
|
||||
def get(self, *, auth_data: AuthData):
|
||||
try:
|
||||
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page, limit=query.limit, total=page_result.total, has_more=False, data=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id: dict[str, App] = {
|
||||
str(a.id): a for a in AppService.find_visible_apps_by_ids(db.session, page_result.app_ids)
|
||||
}
|
||||
tenant_ids = list({str(a.tenant_id) for a in apps_by_id.values()})
|
||||
tenants_by_id = {str(t.id): t for t in TenantService.get_tenants_by_ids(db.session, tenant_ids)}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
env = PermittedExternalAppsListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=page_result.total,
|
||||
has_more=query.page * query.limit < page_result.total,
|
||||
data=items,
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
@ -1,3 +0,0 @@
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
|
||||
__all__ = ["auth_router"]
|
||||
@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.conditions import (
|
||||
EDITION_CE,
|
||||
EDITION_EE,
|
||||
LOADED_APP_IS_PRIVATE,
|
||||
PATH_HAS_APP_ID,
|
||||
WEBAPP_AUTH_ENABLED,
|
||||
)
|
||||
from controllers.openapi.auth.data import Edition
|
||||
from controllers.openapi.auth.flow import When
|
||||
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
|
||||
from controllers.openapi.auth.prepare import (
|
||||
load_account,
|
||||
load_app,
|
||||
load_app_access_mode,
|
||||
load_tenant,
|
||||
resolve_external_user,
|
||||
)
|
||||
from controllers.openapi.auth.verify import (
|
||||
check_acl,
|
||||
check_app_access,
|
||||
check_membership,
|
||||
check_private_app_permission,
|
||||
check_scope,
|
||||
)
|
||||
from libs.oauth_bearer import TokenType
|
||||
|
||||
account_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
load_account, # all tokens here are account tokens
|
||||
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
|
||||
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
|
||||
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
external_sso_pipeline = AuthPipeline(
|
||||
prepare=[
|
||||
When(PATH_HAS_APP_ID, then=load_app),
|
||||
When(PATH_HAS_APP_ID, then=load_tenant),
|
||||
When(PATH_HAS_APP_ID, then=resolve_external_user),
|
||||
When(PATH_HAS_APP_ID, then=load_app_access_mode),
|
||||
],
|
||||
auth=[
|
||||
check_scope,
|
||||
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
|
||||
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
|
||||
],
|
||||
)
|
||||
|
||||
auth_router = PipelineRouter(
|
||||
{
|
||||
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
|
||||
}
|
||||
)
|
||||
@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
|
||||
from libs.oauth_bearer import TokenType
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
CondFn = Callable[[RequestContext, AuthData | None], bool]
|
||||
|
||||
|
||||
class Cond:
|
||||
def __init__(self, fn: CondFn) -> None:
|
||||
self._fn = fn
|
||||
|
||||
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self._fn(ctx, data)
|
||||
|
||||
def __and__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
|
||||
|
||||
def __or__(self, other: Cond) -> Cond:
|
||||
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
|
||||
|
||||
def __invert__(self) -> Cond:
|
||||
return Cond(lambda ctx, data: not self(ctx, data))
|
||||
|
||||
|
||||
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
|
||||
return Cond(lambda ctx, _: fn(ctx))
|
||||
|
||||
|
||||
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
|
||||
return Cond(lambda _, data: data is not None and fn(data))
|
||||
|
||||
|
||||
def config_cond(fn: Callable[[], bool]) -> Cond:
|
||||
return Cond(lambda _, __: fn())
|
||||
|
||||
|
||||
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
|
||||
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
|
||||
|
||||
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
|
||||
|
||||
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
|
||||
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
|
||||
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
|
||||
|
||||
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
|
||||
|
||||
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)
|
||||
@ -1,69 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from configs import dify_config
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models.account import Account, Tenant
|
||||
from models.model import App, EndUser
|
||||
from services.enterprise.enterprise_service import WebAppAccessMode
|
||||
|
||||
|
||||
class Edition(StrEnum):
|
||||
CE = "ce"
|
||||
EE = "ee"
|
||||
SAAS = "saas"
|
||||
|
||||
|
||||
def current_edition() -> Edition:
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
return Edition.SAAS
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return Edition.EE
|
||||
return Edition.CE
|
||||
|
||||
|
||||
class ExternalIdentity(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
email: str
|
||||
issuer: str | None = None
|
||||
|
||||
|
||||
class RequestContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
token_type: TokenType
|
||||
scope: Scope | None = None
|
||||
path_params: dict[str, str]
|
||||
|
||||
|
||||
class AuthData(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
required_scope: Scope | None = None
|
||||
token_type: TokenType
|
||||
account_id: uuid.UUID | None = None
|
||||
token_hash: str
|
||||
token_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope]
|
||||
tenants: dict[str, bool] = Field(default_factory=dict)
|
||||
external_identity: ExternalIdentity | None = None
|
||||
path_params: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
app_access_mode: WebAppAccessMode | None = None
|
||||
|
||||
caller: Account | EndUser | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]:
|
||||
if self.app is None or self.caller is None or self.caller_kind is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app context missing")
|
||||
return self.app, self.caller, self.caller_kind
|
||||
@ -1,19 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from controllers.openapi.auth.conditions import Cond
|
||||
from controllers.openapi.auth.data import AuthData, RequestContext
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
|
||||
self.condition = condition
|
||||
self._step = then
|
||||
|
||||
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
|
||||
return self.condition(ctx, data)
|
||||
|
||||
def __call__(self, arg: Any) -> None:
|
||||
self._step(arg)
|
||||
@ -1,209 +0,0 @@
|
||||
"""Auth pipeline — entry point for all openapi auth.
|
||||
|
||||
`PipelineRouter.guard()` is the only attachment point for endpoints.
|
||||
`AuthPipeline` is a pure step-runner with no routing concerns.
|
||||
`PipelineRoute` binds a pipeline to optional edition requirements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from controllers.openapi.auth.data import (
|
||||
AuthData,
|
||||
Edition,
|
||||
ExternalIdentity,
|
||||
RequestContext,
|
||||
current_edition,
|
||||
)
|
||||
from controllers.openapi.auth.flow import When
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
TokenType,
|
||||
extract_bearer,
|
||||
get_authenticator,
|
||||
reset_auth_ctx,
|
||||
set_auth_ctx,
|
||||
)
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
|
||||
class AuthPipeline:
|
||||
"""Pure step-runner — no routing, no guard.
|
||||
|
||||
Both `prepare` and `auth` steps receive the same `AuthData` instance.
|
||||
`prepare` steps populate it; `auth` steps validate it.
|
||||
"""
|
||||
|
||||
def __init__(self, prepare: list, auth: list) -> None:
|
||||
self._prepare = prepare
|
||||
self._auth = auth
|
||||
|
||||
def _run(
|
||||
self,
|
||||
identity: AuthContext,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
) -> Any:
|
||||
req_ctx = RequestContext(
|
||||
token_type=identity.token_type,
|
||||
scope=scope,
|
||||
path_params=dict(request.view_args or {}),
|
||||
)
|
||||
|
||||
data = AuthData(
|
||||
token_type=identity.token_type,
|
||||
account_id=identity.account_id,
|
||||
token_hash=identity.token_hash,
|
||||
token_id=identity.token_id,
|
||||
scopes=frozenset(identity.scopes),
|
||||
tenants=dict(identity.verified_tenants),
|
||||
required_scope=scope,
|
||||
path_params=dict(req_ctx.path_params),
|
||||
external_identity=(
|
||||
ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer)
|
||||
if identity.subject_email
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
for step in self._prepare:
|
||||
if _should_run(step, req_ctx, data=None):
|
||||
step(data)
|
||||
|
||||
for step in self._auth:
|
||||
if _should_run(step, req_ctx, data=data):
|
||||
step(data)
|
||||
|
||||
reset_token = set_auth_ctx(identity)
|
||||
if data.caller:
|
||||
_mount_flask_login(data.caller)
|
||||
|
||||
try:
|
||||
kwargs["auth_data"] = data
|
||||
return view(*args, **kwargs)
|
||||
finally:
|
||||
reset_auth_ctx(reset_token)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PipelineRoute:
|
||||
pipeline: AuthPipeline
|
||||
required_edition: frozenset[Edition] | None = None
|
||||
|
||||
|
||||
class PipelineRouter:
|
||||
"""Entry point for openapi auth.
|
||||
|
||||
`guard()` is the decorator that endpoints attach to. It applies
|
||||
global gates (edition, token type) then dispatches to the matching
|
||||
`PipelineRoute` for the token type.
|
||||
"""
|
||||
|
||||
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
|
||||
self._routes = routes
|
||||
|
||||
def guard(
|
||||
self,
|
||||
*,
|
||||
scope: Scope | None = None,
|
||||
allowed_token_types: frozenset[TokenType] | None = None,
|
||||
edition: frozenset[Edition] | None = None,
|
||||
) -> Callable:
|
||||
def decorator(view: Callable) -> Callable:
|
||||
@wraps(view)
|
||||
def decorated(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._execute(
|
||||
args,
|
||||
kwargs,
|
||||
view,
|
||||
scope=scope,
|
||||
allowed_token_types=allowed_token_types,
|
||||
edition=edition,
|
||||
)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
args: tuple,
|
||||
kwargs: dict,
|
||||
view: Callable,
|
||||
*,
|
||||
scope: Scope | None,
|
||||
allowed_token_types: frozenset[TokenType] | None,
|
||||
edition: frozenset[Edition] | None,
|
||||
) -> Any:
|
||||
# 404 not 403 — this edition doesn't expose the feature at all
|
||||
if edition is not None and current_edition() not in edition:
|
||||
raise NotFound()
|
||||
|
||||
license_checked = False
|
||||
if edition is not None and Edition.EE in edition:
|
||||
_check_license()
|
||||
license_checked = True
|
||||
|
||||
token = extract_bearer(request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
identity = get_authenticator().authenticate(token)
|
||||
|
||||
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
|
||||
emit_wrong_surface(
|
||||
subject_type=_subject_type_str(identity),
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(identity, "client_id", None),
|
||||
token_id=str(identity.token_id) if identity.token_id else None,
|
||||
)
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
route = self._routes.get(identity.token_type)
|
||||
if route is None:
|
||||
raise Forbidden("unsupported_token_type")
|
||||
|
||||
if route.required_edition is not None:
|
||||
if current_edition() not in route.required_edition:
|
||||
raise Forbidden("external_sso_requires_ee")
|
||||
if not license_checked and Edition.EE in route.required_edition:
|
||||
_check_license()
|
||||
|
||||
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
|
||||
|
||||
|
||||
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
|
||||
if isinstance(step, When):
|
||||
return step.applies(req_ctx, data)
|
||||
return True
|
||||
|
||||
|
||||
def _subject_type_str(identity: Any) -> str | None:
|
||||
subject = getattr(identity, "subject_type", None)
|
||||
if subject is None:
|
||||
return None
|
||||
return subject.value if hasattr(subject, "value") else str(subject)
|
||||
|
||||
|
||||
def _check_license() -> None:
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
|
||||
raise Forbidden("license_invalid")
|
||||
|
||||
|
||||
def _mount_flask_login(user: Any) -> None:
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
|
||||
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]
|
||||
@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def load_app(data: AuthData) -> None:
|
||||
app_id = data.path_params["app_id"]
|
||||
app = AppService.get_app_by_id(db.session, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
data.app = app
|
||||
|
||||
|
||||
def load_tenant(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant")
|
||||
tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id))
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
data.tenant = tenant
|
||||
|
||||
|
||||
def load_account(data: AuthData) -> None:
|
||||
account = AccountService.get_account_by_id(db.session, str(data.account_id))
|
||||
if account is None:
|
||||
raise Unauthorized("account not found")
|
||||
if data.tenant:
|
||||
account.current_tenant = data.tenant
|
||||
data.caller = account
|
||||
data.caller_kind = "account"
|
||||
|
||||
|
||||
def resolve_external_user(data: AuthData) -> None:
|
||||
if data.tenant is None or data.app is None or data.external_identity is None:
|
||||
raise Unauthorized("missing context for external user resolution")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=str(data.tenant.id),
|
||||
app_id=str(data.app.id),
|
||||
user_id=data.external_identity.email,
|
||||
)
|
||||
data.caller = end_user
|
||||
data.caller_kind = "end_user"
|
||||
|
||||
|
||||
def load_app_access_mode(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
return
|
||||
try:
|
||||
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id))
|
||||
if settings is None:
|
||||
data.app_access_mode = None
|
||||
return
|
||||
data.app_access_mode = WebAppAccessMode(settings.access_mode)
|
||||
except ValueError:
|
||||
data.app_access_mode = None
|
||||
@ -1,77 +0,0 @@
|
||||
"""Workspace role gate.
|
||||
|
||||
Layered on top of `validate_bearer` + `accept_subjects(SubjectType.ACCOUNT)`
|
||||
for routes whose access depends on the caller's `TenantAccountJoin.role`
|
||||
in the workspace named by the `workspace_id` path parameter.
|
||||
|
||||
Usage::
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class Members(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role() # any member
|
||||
def get(self, workspace_id: str): ...
|
||||
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
@accept_subjects(SubjectType.ACCOUNT)
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str): ...
|
||||
|
||||
Non-member callers get 404 (matching `GET /openapi/v1/workspaces/<id>`)
|
||||
so workspace IDs do not leak across tenants. A member without one of the
|
||||
allowed roles gets 403.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import try_get_auth_ctx
|
||||
from models.account import TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def require_workspace_role(*allowed_roles: TenantAccountRole) -> Callable[[F], F]:
|
||||
"""Gate a route on the caller's role in ``workspace_id``.
|
||||
|
||||
Pass no roles to require only membership. Pass one or more roles to
|
||||
require the caller's role be in that set.
|
||||
"""
|
||||
|
||||
allowed = frozenset(allowed_roles)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None or ctx.account_id is None:
|
||||
raise RuntimeError(
|
||||
"require_workspace_role called without account-bearer context; "
|
||||
"stack validate_bearer + accept_subjects(SubjectType.ACCOUNT) above it"
|
||||
)
|
||||
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
if not workspace_id:
|
||||
raise RuntimeError("require_workspace_role expects a 'workspace_id' route parameter")
|
||||
|
||||
role = TenantService.get_account_role_in_tenant(db.session, str(ctx.account_id), str(workspace_id))
|
||||
|
||||
if role is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
if allowed and role not in allowed:
|
||||
raise Forbidden("insufficient workspace role")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
@ -1,89 +0,0 @@
|
||||
"""Surface gate.
|
||||
|
||||
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
|
||||
step) is the pipeline-level form. Both delegate to `check_surface` so the
|
||||
audit emit + canonical-path message are single-sourced.
|
||||
|
||||
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
|
||||
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
|
||||
``openapi.wrong_surface_denied``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import request
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi._audit import emit_wrong_surface
|
||||
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
|
||||
|
||||
_CANONICAL_PATH: dict[SubjectType, str] = {
|
||||
SubjectType.ACCOUNT: "/openapi/v1/apps",
|
||||
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
|
||||
}
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., object])
|
||||
|
||||
|
||||
def check_surface(accepted: frozenset[SubjectType]) -> None:
|
||||
"""Enforce that the resolved subject is in ``accepted``.
|
||||
|
||||
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
|
||||
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
|
||||
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
|
||||
set the bearer layer didn't run — that's a wiring bug, not a
|
||||
user-driven failure, so surface it as a ``RuntimeError`` instead of
|
||||
a silent 403.
|
||||
"""
|
||||
ctx = try_get_auth_ctx()
|
||||
if ctx is None:
|
||||
raise RuntimeError(
|
||||
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
|
||||
)
|
||||
|
||||
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
|
||||
if subject in accepted:
|
||||
return
|
||||
|
||||
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
|
||||
emit_wrong_surface(
|
||||
subject_type=subject.value if subject else None,
|
||||
attempted_path=request.path,
|
||||
client_id=getattr(ctx, "client_id", None),
|
||||
token_id=_stringify(getattr(ctx, "token_id", None)),
|
||||
)
|
||||
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
|
||||
|
||||
|
||||
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
|
||||
accepted_set: frozenset[SubjectType] = frozenset(accepted)
|
||||
|
||||
def deco(fn: F) -> F:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: object, **kwargs: object) -> object:
|
||||
check_surface(accepted_set)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return deco
|
||||
|
||||
|
||||
def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
@ -1,82 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from werkzeug.exceptions import Forbidden, Unauthorized
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
|
||||
|
||||
|
||||
def check_scope(data: AuthData) -> None:
|
||||
if data.required_scope is None:
|
||||
return
|
||||
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
def check_membership(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
raise Unauthorized("tenant unset")
|
||||
if data.account_id is None:
|
||||
raise Unauthorized("account_id unset")
|
||||
check_workspace_membership(
|
||||
account_id=data.account_id,
|
||||
tenant_id=data.tenant.id,
|
||||
token_hash=data.token_hash,
|
||||
membership_cache=data.tenants,
|
||||
)
|
||||
|
||||
|
||||
def check_app_access(data: AuthData) -> None:
|
||||
if data.tenant is None:
|
||||
return
|
||||
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
|
||||
TokenType.OAUTH_ACCOUNT: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
WebAppAccessMode.PRIVATE_ALL,
|
||||
WebAppAccessMode.PRIVATE,
|
||||
}
|
||||
),
|
||||
TokenType.OAUTH_EXTERNAL_SSO: frozenset(
|
||||
{
|
||||
WebAppAccessMode.PUBLIC,
|
||||
WebAppAccessMode.SSO_VERIFIED,
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def check_acl(data: AuthData) -> None:
|
||||
if data.app is None or data.app_access_mode is None:
|
||||
raise Forbidden("app or access mode not loaded")
|
||||
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
|
||||
if data.app_access_mode not in allowed_modes:
|
||||
raise Forbidden("subject_not_allowed_for_access_mode")
|
||||
|
||||
|
||||
def check_private_app_permission(data: AuthData) -> None:
|
||||
if data.app is None:
|
||||
raise Forbidden("app not loaded")
|
||||
user_id = _resolve_user_id(data)
|
||||
if user_id is None:
|
||||
raise Forbidden("cannot resolve user for private app check")
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id):
|
||||
raise Forbidden("user_not_allowed_for_private_app")
|
||||
|
||||
|
||||
def _resolve_user_id(data: AuthData) -> str | None:
|
||||
if data.token_type == TokenType.OAUTH_ACCOUNT:
|
||||
return str(data.account_id) if data.account_id is not None else None
|
||||
if data.external_identity is None:
|
||||
return None
|
||||
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
|
||||
return str(account.id) if account is not None else None
|
||||
@ -1,73 +0,0 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/files/upload — upload a file for use in app inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx.api import HTTPStatus
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from libs.oauth_bearer import Scope
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/files/upload")
|
||||
class AppFileUploadApi(Resource):
|
||||
@openapi_ns.doc("upload_file_for_app_input")
|
||||
@openapi_ns.doc(description="Upload a file to use as an input variable when running the app")
|
||||
@openapi_ns.doc(
|
||||
responses={
|
||||
201: "File uploaded successfully",
|
||||
400: "Bad request — no file or filename missing",
|
||||
401: "Unauthorized — invalid or expired bearer token",
|
||||
413: "File too large",
|
||||
415: "Unsupported file type or blocked extension",
|
||||
}
|
||||
)
|
||||
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, _ = auth_data.require_app_context()
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
file = request.files["file"]
|
||||
if not file.mimetype:
|
||||
raise UnsupportedFileTypeError()
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError()
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.stream.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=caller,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except services.errors.file.FileTooLargeError as exc:
|
||||
raise FileTooLargeError(exc.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
except services.errors.file.BlockedFileExtensionError as exc:
|
||||
raise BlockedFileExtensionError(exc.description)
|
||||
|
||||
response = FileResponse.model_validate(upload_file, from_attributes=True)
|
||||
return response.model_dump(mode="json"), 201
|
||||
@ -1,110 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed human input form endpoints.
|
||||
|
||||
GET /apps/<app_id>/form/human_input/<form_token> — fetch paused form definition
|
||||
POST /apps/<app_id>/form/human_input/<form_token> — submit form response
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import to_timestamp
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App
|
||||
from services.human_input_service import FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(openapi_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form) -> Response:
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": stringify_form_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": to_timestamp(form.expiration_time),
|
||||
}
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
def _ensure_form_belongs_to_app(form, app_model: App) -> None:
|
||||
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
def _ensure_form_is_allowed_for_openapi(form) -> None:
|
||||
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.OPENAPI):
|
||||
raise NotFound("Form not found")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
|
||||
class OpenApiWorkflowHumanInputFormApi(Resource):
|
||||
@openapi_ns.response(200, "Form definition")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
service.ensure_form_active(form)
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@openapi_ns.response(200, "Form submitted")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
_ensure_form_belongs_to_app(form, app_model)
|
||||
_ensure_form_is_allowed_for_openapi(form)
|
||||
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
if caller_kind == "account":
|
||||
submission_user_id = caller.id
|
||||
else:
|
||||
submission_end_user_id = caller.id
|
||||
|
||||
if form.recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_token=%s", form_token)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=form.recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=payload.action,
|
||||
form_data=payload.inputs,
|
||||
submission_user_id=submission_user_id,
|
||||
submission_end_user_id=submission_end_user_id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFound("Form not found")
|
||||
|
||||
return {}, 200
|
||||
@ -1,9 +0,0 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
|
||||
class HealthApi(Resource):
|
||||
def get(self):
|
||||
return {"ok": True}
|
||||
@ -1,398 +0,0 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
AccountPayload,
|
||||
DeviceCodeRequest,
|
||||
DeviceCodeResponse,
|
||||
DeviceLookupQuery,
|
||||
DeviceLookupResponse,
|
||||
DeviceMutateRequest,
|
||||
DeviceMutateResponse,
|
||||
DevicePollRequest,
|
||||
WorkspacePayload,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import TenantService
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Validation helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceCodeRequest.__name__])
|
||||
@openapi_ns.response(200, "Device code created", openapi_ns.models[DeviceCodeResponse.__name__])
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[DevicePollRequest.__name__])
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(DeviceLookupQuery))
|
||||
@openapi_ns.response(200, "Device lookup result", openapi_ns.models[DeviceLookupResponse.__name__])
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
|
||||
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
|
||||
class DeviceApproveApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Approved", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
account, tenant = current_account_with_tenant()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
# SET NX guard — without it, two in-flight approves both pass
|
||||
# PENDING, both mint, and the second upsert silently rotates the
|
||||
# first caller into an already-revoked token.
|
||||
guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
|
||||
if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
|
||||
return {"error": "approve_in_progress"}, 409
|
||||
|
||||
try:
|
||||
profile = MINTABLE_PROFILES[SubjectType.ACCOUNT]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
ttl_days = oauth_ttl_days(tenant_id=tenant)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=account.email,
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
account_id=str(account.id),
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = _build_account_poll_payload(account, tenant, mint)
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=account.email,
|
||||
account_id=str(account.id),
|
||||
subject_issuer=ACCOUNT_ISSUER_SENTINEL,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
# Row minted but state vanished — roll forward; the orphan
|
||||
# token is revocable via auth devices list / Authorized Apps.
|
||||
logger.exception("device_flow: approve raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
_emit_approve_audit(state, account, tenant, mint)
|
||||
return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
|
||||
class DeviceDenyApi(Resource):
|
||||
@openapi_ns.expect(openapi_ns.models[DeviceMutateRequest.__name__])
|
||||
@openapi_ns.response(200, "Denied", openapi_ns.models[DeviceMutateResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@bearer_feature_required
|
||||
@rate_limit(LIMIT_APPROVE_CONSOLE)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceMutateRequest)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"error": "expired_or_unknown"}, 404
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"error": "already_resolved"}, 409
|
||||
|
||||
try:
|
||||
store.deny(device_code)
|
||||
except (StateNotFoundError, InvalidTransitionError):
|
||||
logger.exception("device_flow: deny raced on %s", device_code)
|
||||
return {"error": "state_lost"}, 409
|
||||
|
||||
_emit_deny_audit(state)
|
||||
return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
|
||||
base = getattr(dify_config, "CONSOLE_WEB_URL", None)
|
||||
if base:
|
||||
return f"{base.rstrip('/')}/device"
|
||||
return f"{request.host_url.rstrip('/')}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
|
||||
poll_ip = extract_remote_ip(request)
|
||||
if state.created_ip and poll_ip and poll_ip != state.created_ip:
|
||||
logger.warning(
|
||||
"audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
|
||||
state.token_id,
|
||||
state.created_ip,
|
||||
poll_ip,
|
||||
extra={
|
||||
"audit": True,
|
||||
"token_id": state.token_id,
|
||||
"creation_ip": state.created_ip,
|
||||
"poll_ip": poll_ip,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(account.id))
|
||||
workspaces = [WorkspacePayload(id=str(t.id), name=t.name, role=getattr(m, "role", "")) for t, m in rows]
|
||||
# Prefer active session tenant → DB-flagged current join → first membership.
|
||||
default_ws_id = None
|
||||
if tenant and any(w.id == str(tenant) for w in workspaces):
|
||||
default_ws_id = str(tenant)
|
||||
if default_ws_id is None:
|
||||
for _t, m in rows:
|
||||
if getattr(m, "current", False):
|
||||
default_ws_id = str(m.tenant_id)
|
||||
break
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"account": AccountPayload(id=str(account.id), email=account.email, name=account.name).model_dump(mode="json"),
|
||||
"workspaces": [w.model_dump(mode="json") for w in workspaces],
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
|
||||
mint.token_id,
|
||||
account.email,
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
mint.expires_at,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"token_id": str(mint.token_id),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
"subject_email": account.email,
|
||||
"account_id": str(account.id),
|
||||
"tenant_id": tenant,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["full"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_denied client_id=%s device_label=%s",
|
||||
state.client_id,
|
||||
state.device_label,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_denied",
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
@ -1,365 +0,0 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import bp
|
||||
from controllers.openapi._models import ExtSubjectAssertionClaims
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import MINTABLE_PROFILES, SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
from services.openapi.mint_policy import MintPolicyViolation, validate_mint_policy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
def _trusted_origin() -> str:
|
||||
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
|
||||
if not base:
|
||||
raise BadGateway("console_api_url_unset")
|
||||
return base
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
origin = _trusted_origin()
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
try:
|
||||
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
|
||||
except ValidationError as e:
|
||||
logger.warning("sso-complete: claim shape invalid: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = claims.user_code.strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.email):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = _trusted_origin()
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims.email,
|
||||
subject_issuer=claims.issuer,
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims.subject_email):
|
||||
_emit_external_rejection_audit(state, claims, reason="email_belongs_to_dify_account")
|
||||
raise Forbidden("email_belongs_to_dify_account")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
profile = MINTABLE_PROFILES[SubjectType.EXTERNAL_SSO]
|
||||
try:
|
||||
validate_mint_policy(
|
||||
subject_type=profile.subject_type,
|
||||
prefix=profile.prefix,
|
||||
scopes=profile.scopes,
|
||||
)
|
||||
except MintPolicyViolation as e:
|
||||
raise BadRequest(description=str(e)) from None
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=profile.prefix,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
# SSO branch of the shared PollPayload contract: account/workspace
|
||||
# fields are zero-filled (`None` / `[]`) for parity with the account
|
||||
# branch in `oauth_device._build_account_poll_payload`.
|
||||
poll_payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RejectedClaims:
|
||||
"""Minimal subject shape consumed by `_emit_external_rejection_audit`.
|
||||
|
||||
Mirrors the attributes used from `ApprovalGrantClaims` so callers holding
|
||||
only a raw JWS claims dict (e.g. `sso_complete`) can emit the same audit
|
||||
event without reaching for the full dataclass.
|
||||
"""
|
||||
|
||||
subject_email: str
|
||||
subject_issuer: str
|
||||
|
||||
|
||||
def _emit_external_rejection_audit(state, claims, *, reason: str) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_rejected subject_type=%s subject_email=%s subject_issuer=%s reason=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
reason,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_rejected",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"reason": reason,
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
@ -1,121 +0,0 @@
|
||||
"""
|
||||
OpenAPI bearer-authed workflow reconnect event stream endpoint.
|
||||
|
||||
GET /apps/<app_id>/tasks/<task_id>/events
|
||||
— reconnect to the SSE stream for a paused/running workflow run.
|
||||
`task_id` is treated as `workflow_run_id`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
|
||||
class OpenApiWorkflowEventsApi(Resource):
|
||||
@openapi_ns.response(200, "SSE event stream")
|
||||
@auth_router.guard(scope=Scope.APPS_RUN)
|
||||
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
|
||||
app_model, caller, caller_kind = auth_data.require_app_context()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
|
||||
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")
|
||||
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=task_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
if caller_kind == "account":
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
else:
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER or workflow_run.created_by != caller.id:
|
||||
raise NotFound("Workflow run not found")
|
||||
|
||||
workflow_run_entity = workflow_run
|
||||
|
||||
if workflow_run_entity.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run_entity.id,
|
||||
workflow_run=workflow_run_entity,
|
||||
creator_user=caller,
|
||||
)
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
else:
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run_entity,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.OPENAPI,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(
|
||||
app_mode,
|
||||
workflow_run_entity.id,
|
||||
terminal_events=terminal_events,
|
||||
),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
@ -1,329 +0,0 @@
|
||||
"""User-scoped workspace reads and member management under /openapi/v1/workspaces.
|
||||
|
||||
Bearer-authed counterparts to the cookie-authed /console/api/workspaces
|
||||
endpoints. Account bearers (dfoa_) see every tenant they're a member of.
|
||||
External SSO bearers (dfoe_) have no account_id and so see an empty list —
|
||||
that matches /openapi/v1/account.
|
||||
|
||||
Member-management endpoints are gated by both `accept_subjects` (SSO out)
|
||||
and `require_workspace_role` (membership / role lookup against the path's
|
||||
``workspace_id``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
from urllib import parse
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import query_params_from_model
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MemberActionResponse,
|
||||
MemberInvitePayload,
|
||||
MemberInviteResponse,
|
||||
MemberListQuery,
|
||||
MemberListResponse,
|
||||
MemberResponse,
|
||||
MemberRoleUpdatePayload,
|
||||
WorkspaceDetailResponse,
|
||||
WorkspaceListResponse,
|
||||
WorkspaceSummaryResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import auth_router
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
from controllers.openapi.auth.role_gate import require_workspace_role
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import Scope, TokenType
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.account import TenantAccountRole, TenantStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
AccountNotLinkTenantError,
|
||||
AccountRegisterError,
|
||||
CannotOperateSelfError,
|
||||
MemberNotInTenantError,
|
||||
NoPermissionError,
|
||||
RoleAlreadyAssignedError,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _validate_body[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _member_response(account: Account) -> MemberResponse:
|
||||
return MemberResponse(
|
||||
id=str(account.id),
|
||||
name=account.name,
|
||||
email=account.email,
|
||||
role=account.role.value if account.role else "",
|
||||
status=account.status.value if account.status else "",
|
||||
avatar=account.avatar,
|
||||
)
|
||||
|
||||
|
||||
def _load_tenant(workspace_id: str) -> Tenant:
|
||||
tenant = TenantService.get_tenant_by_id(db.session, workspace_id)
|
||||
if tenant is None or tenant.status != TenantStatus.NORMAL:
|
||||
raise NotFound("workspace not found")
|
||||
return tenant
|
||||
|
||||
|
||||
def _load_account(account_id: object) -> Account:
|
||||
account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None
|
||||
if account is None:
|
||||
raise RuntimeError("authenticated account_id has no Account row")
|
||||
return account
|
||||
|
||||
|
||||
def _quota_error(*, code: str, message: str, hint: str) -> Forbidden:
|
||||
err = Forbidden(message)
|
||||
err.response = make_response(
|
||||
jsonify({"code": code, "message": message, "hint": hint}),
|
||||
403,
|
||||
)
|
||||
return err
|
||||
|
||||
|
||||
def _check_member_invite_quota(tenant_id: str) -> None:
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
if 0 < members.limit <= members.size:
|
||||
raise _quota_error(
|
||||
code="members.limit_exceeded",
|
||||
message="Subscription member limit reached.",
|
||||
hint="Upgrade your plan to invite more members or remove an existing member first.",
|
||||
)
|
||||
|
||||
if features.workspace_members.enabled:
|
||||
if not features.workspace_members.is_available(1):
|
||||
raise _quota_error(
|
||||
code="workspace_members.license_exceeded",
|
||||
message="Workspace member license capacity reached.",
|
||||
hint="Contact your workspace administrator to expand the license seat count.",
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, *, auth_data: AuthData):
|
||||
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
|
||||
|
||||
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/switch")
|
||||
class WorkspaceSwitchApi(Resource):
|
||||
"""Server-side switch — equivalent to the console's POST /workspaces/switch.
|
||||
|
||||
CLI `difyctl use workspace <id>` calls this; it does NOT mutate
|
||||
``hosts.yml`` on its own. Failure here must abort the local write so
|
||||
that ``hosts.yml`` never diverges from the server's ``current`` state.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
account = _load_account(auth_data.account_id)
|
||||
|
||||
try:
|
||||
TenantService.switch_tenant(account, workspace_id)
|
||||
except AccountNotLinkTenantError:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership).model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members")
|
||||
class WorkspaceMembersApi(Resource):
|
||||
"""List + invite members.
|
||||
|
||||
GET is any-member. POST requires admin/owner — owner can never be
|
||||
assigned through invite (ownership transfer is console-only).
|
||||
"""
|
||||
|
||||
@openapi_ns.doc(params=query_params_from_model(MemberListQuery))
|
||||
@openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role()
|
||||
def get(self, workspace_id: str, *, auth_data: AuthData):
|
||||
try:
|
||||
query = MemberListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
tenant = _load_tenant(workspace_id)
|
||||
members = TenantService.get_tenant_members(tenant)
|
||||
total = len(members)
|
||||
start = (query.page - 1) * query.limit
|
||||
page_items = members[start : start + query.limit]
|
||||
return MemberListResponse(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
total=total,
|
||||
has_more=query.page * query.limit < total,
|
||||
data=[_member_response(m) for m in page_items],
|
||||
).model_dump(mode="json"), 200
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__])
|
||||
@openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def post(self, workspace_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberInvitePayload)
|
||||
inviter = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
|
||||
_check_member_invite_quota(str(tenant.id))
|
||||
|
||||
try:
|
||||
token = RegisterService.invite_new_member(
|
||||
tenant=tenant,
|
||||
email=payload.email,
|
||||
language=None,
|
||||
role=payload.role,
|
||||
inviter=inviter,
|
||||
)
|
||||
except AccountAlreadyInTenantError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except AccountRegisterError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
normalized_email = payload.email.lower()
|
||||
member = AccountService.get_account_by_email_with_case_fallback(normalized_email)
|
||||
if member is None:
|
||||
# invite_new_member just created or fetched this account.
|
||||
raise RuntimeError("invited member missing from DB after invite")
|
||||
|
||||
encoded_email = parse.quote(normalized_email)
|
||||
invite_url = f"{dify_config.CONSOLE_WEB_URL}/activate?email={encoded_email}&token={token}"
|
||||
return MemberInviteResponse(
|
||||
email=normalized_email,
|
||||
role=payload.role,
|
||||
member_id=str(member.id),
|
||||
invite_url=invite_url,
|
||||
tenant_id=str(tenant.id),
|
||||
).model_dump(mode="json"), 201
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>")
|
||||
class WorkspaceMemberApi(Resource):
|
||||
"""Remove a member.
|
||||
|
||||
Self-removal and owner-removal are explicitly rejected by the service
|
||||
layer (CannotOperateSelfError, NoPermissionError) — both surface as
|
||||
400 per the spec, with the service's message preserved.
|
||||
"""
|
||||
|
||||
@openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.remove_member_from_tenant(tenant, member, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>/members/<string:member_id>/role")
|
||||
class WorkspaceMemberRoleApi(Resource):
|
||||
"""Change a member's role.
|
||||
|
||||
Owner cannot be assigned here (closed enum). Admin cannot demote the
|
||||
standing owner (service NoPermissionError → 400, per spec).
|
||||
"""
|
||||
|
||||
@openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__])
|
||||
@openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__])
|
||||
@auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
|
||||
@require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN)
|
||||
def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData):
|
||||
payload = _validate_body(MemberRoleUpdatePayload)
|
||||
operator = _load_account(auth_data.account_id)
|
||||
tenant = _load_tenant(workspace_id)
|
||||
member = AccountService.get_account_by_id(db.session, member_id)
|
||||
if member is None:
|
||||
raise NotFound("member not found")
|
||||
|
||||
try:
|
||||
TenantService.update_member_role(tenant, member, payload.role, operator)
|
||||
except CannotOperateSelfError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except NoPermissionError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
except MemberNotInTenantError as exc:
|
||||
raise NotFound(str(exc))
|
||||
except RoleAlreadyAssignedError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
return MemberActionResponse().model_dump(mode="json"), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceSummaryResponse:
|
||||
return WorkspaceSummaryResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
)
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> WorkspaceDetailResponse:
|
||||
return WorkspaceDetailResponse(
|
||||
id=str(tenant.id),
|
||||
name=tenant.name,
|
||||
role=getattr(membership, "role", ""),
|
||||
status=tenant.status,
|
||||
current=getattr(membership, "current", False),
|
||||
created_at=tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
)
|
||||
@ -174,11 +174,11 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, annotation_id: UUID):
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -195,7 +195,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: UUID):
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
return "", 204
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -51,20 +50,20 @@ class FilePreviewApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: str):
|
||||
"""
|
||||
Preview/Download a file that was uploaded via Service API.
|
||||
|
||||
Provides secure file preview/download functionality.
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
file_id_str = str(file_id)
|
||||
file_id = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
from uuid import UUID
|
||||
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.hit_testing_fields import HitTestingResponse
|
||||
from libs.helper import dump_response
|
||||
|
||||
register_schema_models(service_api_ns, HitTestingPayload)
|
||||
register_response_schema_models(service_api_ns, HitTestingResponse)
|
||||
register_schema_model(service_api_ns, HitTestingPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
|
||||
@ -16,16 +13,16 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@service_api_ns.doc("dataset_hit_testing")
|
||||
@service_api_ns.doc(description="Perform hit testing on a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Hit testing results",
|
||||
model=service_api_ns.models[HitTestingResponse.__name__],
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Hit testing results",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(401, "Unauthorized - invalid API token")
|
||||
@service_api_ns.response(404, "Dataset not found")
|
||||
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID) -> dict[str, object]:
|
||||
def post(self, tenant_id, dataset_id: UUID):
|
||||
"""Perform hit testing on a dataset.
|
||||
|
||||
Tests retrieval performance for the specified dataset.
|
||||
@ -36,4 +33,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
args = self.parse_args(service_api_ns.payload)
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args))
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel
|
||||
@ -65,11 +64,10 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID):
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -79,7 +77,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
|
||||
)
|
||||
return datasource_plugins, 200
|
||||
|
||||
@ -111,11 +109,10 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: UUID, node_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -123,7 +120,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
|
||||
{
|
||||
**payload.model_dump(exclude_none=True),
|
||||
@ -175,11 +172,10 @@ class PipelineRunApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -190,7 +186,7 @@ class PipelineRunApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
try:
|
||||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
@ -108,19 +107,17 @@ class SegmentApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create single segment."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.indexing_status != "completed":
|
||||
@ -153,10 +150,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@ -171,21 +165,19 @@ class SegmentApi(DatasetApiResource):
|
||||
404: "Dataset or document not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get segments."""
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check embedding model setting
|
||||
@ -213,7 +205,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id_str,
|
||||
document_id=document_id,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args.status,
|
||||
keyword=args.keyword,
|
||||
@ -222,7 +214,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@ -248,25 +240,22 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
@ -287,20 +276,18 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@ -319,19 +306,15 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {
|
||||
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@ -342,29 +325,26 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -389,26 +369,23 @@ class ChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
@ -452,26 +429,23 @@ class ChildChunkApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
@ -487,9 +461,7 @@ class ChildChunkApi(DatasetApiResource):
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
@ -525,38 +497,32 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
if str(segment.document_id) != str(document_id):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# check child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
@ -592,38 +558,32 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# get document
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# get segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
if str(segment.document_id) != str(document_id):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# get child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
@ -141,26 +140,20 @@ def cloud_edition_billing_resource_check[**P, R](
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
if resource == "vector_space":
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
vector_space = FeatureService.get_vector_space(api_token.tenant_id)
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
members = features.members
|
||||
apps = features.apps
|
||||
vector_space = features.vector_space
|
||||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
@ -181,7 +174,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
features = FeatureService.get_features(api_token.tenant_id, exclude_vector_space=True)
|
||||
features = FeatureService.get_features(api_token.tenant_id)
|
||||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == CloudPlan.SANDBOX:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user