Compare commits

..

1 Commits

Author SHA1 Message Date
f5528f2030 Initial plan 2025-11-18 16:22:22 +00:00
1080 changed files with 15228 additions and 104425 deletions

View File

@ -1,6 +0,0 @@
# Cursor Rules for Dify Project
## Automated Test Generation
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -29,7 +29,7 @@ trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation # Matches multiple files with brace expansion notation
# Set default charset # Set default charset
[*.{js,jsx,ts,tsx,mjs}] [*.{js,tsx}]
indent_style = space indent_style = space
indent_size = 2 indent_size = 2

226
.github/CODEOWNERS vendored
View File

@ -1,226 +0,0 @@
# CODEOWNERS
# This file defines code ownership for the Dify project.
# Each line is a file pattern followed by one or more owners.
# Owners can be @username, @org/team-name, or email addresses.
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
* @crazywoola @laipz8200 @Yeuoly
# Backend (default owner, more specific rules below will override)
api/ @QuantumGhost
# Backend - Workflow - Engine (Core graph execution engine)
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
api/core/workflow/graph/ @laipz8200 @QuantumGhost
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
api/core/model_runtime/ @laipz8200 @QuantumGhost
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
api/core/workflow/nodes/agent/ @Nov1c444
api/core/workflow/nodes/iteration/ @Nov1c444
api/core/workflow/nodes/loop/ @Nov1c444
api/core/workflow/nodes/llm/ @Nov1c444
# Backend - RAG (Retrieval Augmented Generation)
api/core/rag/ @JohnJyong
api/services/rag_pipeline/ @JohnJyong
api/services/dataset_service.py @JohnJyong
api/services/knowledge_service.py @JohnJyong
api/services/external_knowledge_service.py @JohnJyong
api/services/hit_testing_service.py @JohnJyong
api/services/metadata_service.py @JohnJyong
api/services/vector_service.py @JohnJyong
api/services/entities/knowledge_entities/ @JohnJyong
api/services/entities/external_knowledge_entities/ @JohnJyong
api/controllers/console/datasets/ @JohnJyong
api/controllers/service_api/dataset/ @JohnJyong
api/models/dataset.py @JohnJyong
api/tasks/rag_pipeline/ @JohnJyong
api/tasks/add_document_to_index_task.py @JohnJyong
api/tasks/batch_clean_document_task.py @JohnJyong
api/tasks/clean_document_task.py @JohnJyong
api/tasks/clean_notion_document_task.py @JohnJyong
api/tasks/document_indexing_task.py @JohnJyong
api/tasks/document_indexing_sync_task.py @JohnJyong
api/tasks/document_indexing_update_task.py @JohnJyong
api/tasks/duplicate_document_indexing_task.py @JohnJyong
api/tasks/recover_document_indexing_task.py @JohnJyong
api/tasks/remove_document_from_index_task.py @JohnJyong
api/tasks/retry_document_indexing_task.py @JohnJyong
api/tasks/sync_website_document_indexing_task.py @JohnJyong
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
api/tasks/create_segment_to_index_task.py @JohnJyong
api/tasks/delete_segment_from_index_task.py @JohnJyong
api/tasks/disable_segment_from_index_task.py @JohnJyong
api/tasks/disable_segments_from_index_task.py @JohnJyong
api/tasks/enable_segment_to_index_task.py @JohnJyong
api/tasks/enable_segments_to_index_task.py @JohnJyong
api/tasks/clean_dataset_task.py @JohnJyong
api/tasks/deal_dataset_index_update_task.py @JohnJyong
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
# Backend - Plugins
api/core/plugin/ @Mairuis @Yeuoly @Stream29
api/services/plugin/ @Mairuis @Yeuoly @Stream29
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
# Backend - Trigger/Schedule/Webhook
api/controllers/trigger/ @Mairuis @Yeuoly
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
api/core/trigger/ @Mairuis @Yeuoly
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
api/services/trigger/ @Mairuis @Yeuoly
api/models/trigger.py @Mairuis @Yeuoly
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
api/libs/schedule_utils.py @Mairuis @Yeuoly
api/services/workflow/scheduler.py @Mairuis @Yeuoly
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
# Backend - Async Workflow
api/services/async_workflow_service.py @Mairuis @Yeuoly
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
# Backend - Billing
api/services/billing_service.py @hj24 @zyssyz123
api/controllers/console/billing/ @hj24 @zyssyz123
# Backend - Enterprise
api/configs/enterprise/ @GarfieldDai @GareArc
api/services/enterprise/ @GarfieldDai @GareArc
api/services/feature_service.py @GarfieldDai @GareArc
api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200
# Frontend
web/ @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
web/app/components/app/configuration/ @iamjoel @zxhlyh
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
# Frontend - WebApp - Chat
web/app/components/base/chat/ @iamjoel @zxhlyh
# Frontend - WebApp - Completion
web/app/components/share/text-generation/ @iamjoel @zxhlyh
# Frontend - App - List and Creation
web/app/components/apps/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
# Frontend - App - API Documentation
web/app/components/develop/ @JzoNgKVO @iamjoel
# Frontend - App - Logs and Annotations
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
web/app/components/app/log/ @JzoNgKVO @iamjoel
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
# Frontend - App - Monitoring
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
web/app/components/app/overview/ @JzoNgKVO @iamjoel
# Frontend - App - Settings
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
# Frontend - RAG - Hit Testing
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
# Frontend - RAG - List and Creation
web/app/components/datasets/list/ @iamjoel @WTW0313
web/app/components/datasets/create/ @iamjoel @WTW0313
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
web/app/components/rag-pipeline/ @iamjoel @WTW0313
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
# Frontend - RAG - Documents List
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
# Frontend - RAG - Segments List
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
# Frontend - RAG - Settings
web/app/components/datasets/settings/ @iamjoel @WTW0313
# Frontend - Ecosystem - Plugins
web/app/components/plugins/ @iamjoel @zhsama
# Frontend - Ecosystem - Tools
web/app/components/tools/ @iamjoel @Yessenia-d
# Frontend - Ecosystem - MarketPlace
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
# Frontend - Login and Registration
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
web/app/account/ @douxc @iamjoel
# Frontend - Service Authentication
web/service/base.ts @douxc @iamjoel
# Frontend - WebApp Authentication and Access Control
web/app/(shareLayout)/components/ @douxc @iamjoel
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
web/app/components/app/app-access-control/ @douxc @iamjoel
# Frontend - Explore Page
web/app/components/explore/ @CodingOnStar @iamjoel
# Frontend - Personal Settings
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
# Frontend - Analytics
web/app/components/base/ga/ @CodingOnStar @iamjoel
# Frontend - Base Components
web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Utils and Hooks
web/utils/classnames.ts @iamjoel @zxhlyh
web/utils/time.ts @iamjoel @zxhlyh
web/utils/format.ts @iamjoel @zxhlyh
web/utils/clipboard.ts @iamjoel @zxhlyh
web/hooks/use-document-title.ts @iamjoel @zxhlyh
# Frontend - Billing and Education
web/app/components/billing/ @iamjoel @zxhlyh
web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh

View File

@ -1,12 +0,0 @@
# Copilot Instructions
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
Key reminders:
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
- Target >95% line and branch coverage and 100% function/statement coverage.
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.

View File

@ -62,7 +62,7 @@ jobs:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
services: | services: |
db_postgres db
redis redis
sandbox sandbox
ssrf_proxy ssrf_proxy

View File

@ -8,7 +8,7 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
db-migration-test-postgres: db-migration-test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@ -45,7 +45,7 @@ jobs:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
services: | services: |
db_postgres db
redis redis
- name: Prepare configs - name: Prepare configs
@ -57,60 +57,3 @@ jobs:
env: env:
DEBUG: true DEBUG: true
run: uv run --directory api flask upgrade-db run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env for MySQL
run: |
cd docker
cp middleware.env.example middleware.env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db_mysql
redis
- name: Prepare configs for MySQL
run: |
cd api
cp .env.example .env
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
- name: Run DB Migration
env:
DEBUG: true
run: uv run --directory api flask upgrade-db

View File

@ -20,22 +20,22 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 2
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US - name: Check for file changes in i18n/en-US
id: check_files id: check_files
run: | run: |
git fetch origin "${{ github.event.before }}" || true recent_commit_sha=$(git rev-parse HEAD)
git fetch origin "${{ github.sha }}" || true second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts') changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files" echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args="" file_args=""
for file in $changed_files; do for file in $changed_files; do
filename=$(basename "$file" .ts) filename=$(basename "$file" .ts)
file_args="$file_args --file $filename" file_args="$file_args --file=$filename"
done done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args" echo "File arguments: $file_args"
@ -77,15 +77,12 @@ jobs:
uses: peter-evans/create-pull-request@v6 uses: peter-evans/create-pull-request@v6
with: with:
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes' commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore(i18n): translate i18n files and update type definitions' title: 'chore: translate i18n files and update type definitions'
body: | body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale. This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:** **Changes included:**
- Updated translation files for all locales - Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety - Regenerated TypeScript type definitions for type safety
branch: chore/automated-i18n-updates-${{ github.sha }} branch: chore/automated-i18n-updates
delete-branch: true

View File

@ -1,7 +1,10 @@
name: Run VDB Tests name: Run VDB Tests
on: on:
workflow_call: push:
branches: [main]
paths:
- 'api/core/rag/*.py'
concurrency: concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }} group: vdb-tests-${{ github.head_ref || github.run_id }}
@ -51,13 +54,13 @@ jobs:
- name: Expose Service Ports - name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB) - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2 uses: hoverkraft-tech/compose-action@v2.0.2
# with: with:
# compose-file: docker/tidb/docker-compose.yaml compose-file: docker/tidb/docker-compose.yaml
# services: | services: |
# tidb tidb
# tiflash tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase) - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2 uses: hoverkraft-tech/compose-action@v2.0.2
@ -83,8 +86,8 @@ jobs:
ls -lah . ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB) - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

2
.gitignore vendored
View File

@ -186,8 +186,6 @@ docker/volumes/couchbase/*
docker/volumes/oceanbase/* docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/* docker/volumes/plugin_daemon/*
docker/volumes/matrixone/* docker/volumes/matrixone/*
docker/volumes/mysql/*
docker/volumes/seekdb/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf

View File

@ -37,7 +37,7 @@
"-c", "-c",
"1", "1",
"-Q", "-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor", "dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline",
"--loglevel", "--loglevel",
"INFO" "INFO"
], ],

View File

@ -1,5 +0,0 @@
# Windsurf Testing Rules
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
- Honor every requirement in that document when generating or accepting tests.
- When proposing or saving tests, re-read that document and follow every requirement.

View File

@ -77,8 +77,6 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly. For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend #### Backend
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly. For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.

View File

@ -70,11 +70,6 @@ type-check:
@uv run --directory api --dev basedpyright @uv run --directory api --dev basedpyright
@echo "✅ Type check complete" @echo "✅ Type check complete"
test:
@echo "🧪 Running backend unit tests..."
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
@echo "✅ Tests complete"
# Build Docker images # Build Docker images
build-web: build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..." @echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -124,7 +119,6 @@ help:
@echo " make check - Check code with ruff" @echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff" @echo " make lint - Format and fix code with ruff"
@echo " make type-check - Run type checking with basedpyright" @echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo "" @echo ""
@echo "Docker Build Targets:" @echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image" @echo " make build-web - Build web Docker image"
@ -134,4 +128,4 @@ help:
@echo " make build-push-all - Build and push all Docker images" @echo " make build-push-all - Build and push all Docker images"
# Phony targets # Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test .PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check

View File

@ -36,12 +36,6 @@
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a> <img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
<a href="https://github.com/langgenius/dify/discussions/" target="_blank"> <a href="https://github.com/langgenius/dify/discussions/" target="_blank">
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a> <img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
</p> </p>
<p align="center"> <p align="center">

View File

@ -72,15 +72,12 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration # celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis CELERY_BACKEND=redis
# PostgreSQL database configuration
# Database configuration
DB_TYPE=postgresql
DB_USERNAME=postgres DB_USERNAME=postgres
DB_PASSWORD=difyai123456 DB_PASSWORD=difyai123456
DB_HOST=localhost DB_HOST=localhost
DB_PORT=5432 DB_PORT=5432
DB_DATABASE=dify DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true SQLALCHEMY_POOL_PRE_PING=true
SQLALCHEMY_POOL_TIMEOUT=30 SQLALCHEMY_POOL_TIMEOUT=30
@ -166,7 +163,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
COOKIE_DOMAIN= COOKIE_DOMAIN=
# Vector database configuration # Vector database configuration
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database # Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index VECTOR_INDEX_NAME_PREFIX=Vector_index
@ -176,18 +173,6 @@ WEAVIATE_ENDPOINT=http://localhost:8080
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100 WEAVIATE_BATCH_SIZE=100
WEAVIATE_TOKENIZATION=word
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
OCEANBASE_FULLTEXT_PARSER=ik
SEEKDB_MEMORY_LIMIT=2G
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode # Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333 QDRANT_URL=http://localhost:6333
@ -354,6 +339,15 @@ LINDORM_PASSWORD=admin
LINDORM_USING_UGC=True LINDORM_USING_UGC=True
LINDORM_QUERY_TIMEOUT=1 LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=difyai123456
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
OCEANBASE_ENABLE_HYBRID_SEARCH=false
# AlibabaCloud MySQL Vector configuration # AlibabaCloud MySQL Vector configuration
ALIBABACLOUD_MYSQL_HOST=127.0.0.1 ALIBABACLOUD_MYSQL_HOST=127.0.0.1
ALIBABACLOUD_MYSQL_PORT=3306 ALIBABACLOUD_MYSQL_PORT=3306
@ -540,7 +534,6 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_DEFAULT_ACTIVE_REQUESTS=0
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration # Celery beat configuration

View File

@ -16,7 +16,6 @@ layers =
graph graph
nodes nodes
node_events node_events
runtime
entities entities
containers = containers =
core.workflow core.workflow

View File

@ -48,12 +48,6 @@ ENV PYTHONIOENCODING=utf-8
WORKDIR /app/api WORKDIR /app/api
# Create non-root user
ARG dify_uid=1001
RUN groupadd -r -g ${dify_uid} dify && \
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
chown -R dify:dify /app
RUN \ RUN \
apt-get update \ apt-get update \
# Install dependencies # Install dependencies
@ -63,7 +57,7 @@ RUN \
# for gmpy2 \ # for gmpy2 \
libgmp-dev libmpfr-dev libmpc-dev \ libgmp-dev libmpfr-dev libmpc-dev \
# For Security # For Security
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \ expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2 # install fonts to support the use of tools like pypdfium2
fonts-noto-cjk \ fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension # install a package to improve the accuracy of guessing mime type and file extension
@ -75,29 +69,24 @@ RUN \
# Copy Python environment and packages # Copy Python environment and packages
ENV VIRTUAL_ENV=/app/api/.venv ENV VIRTUAL_ENV=/app/api/.venv
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV} COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data # Download nltk data
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \ RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
# Copy source code # Copy source code
COPY --chown=dify:dify . /app/api/ COPY . /app/api/
# Prepare entrypoint script
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
# Copy entrypoint
COPY docker/entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
ARG COMMIT_SHA ARG COMMIT_SHA
ENV COMMIT_SHA=${COMMIT_SHA} ENV COMMIT_SHA=${COMMIT_SHA}
ENV NLTK_DATA=/usr/local/share/nltk_data
USER dify
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"] ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]

View File

@ -15,8 +15,8 @@
```bash ```bash
cd ../docker cd ../docker
cp middleware.env.example middleware.env cp middleware.env.example middleware.env
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate # change the profile to other vector database if you are not using weaviate
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
cd ../api cd ../api
``` ```
@ -84,7 +84,7 @@
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,priority_pipeline,pipeline
``` ```
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service: Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:

View File

@ -18,7 +18,6 @@ def create_flask_app_with_configs() -> DifyApp:
""" """
dify_app = DifyApp(__name__) dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump()) dify_app.config.from_mapping(dify_config.model_dump())
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
# add before request hook # add before request hook
@dify_app.before_request @dify_app.before_request
@ -51,7 +50,6 @@ def initialize_extensions(app: DifyApp):
ext_commands, ext_commands,
ext_compress, ext_compress,
ext_database, ext_database,
ext_forward_refs,
ext_hosting_provider, ext_hosting_provider,
ext_import_modules, ext_import_modules,
ext_logging, ext_logging,
@ -76,7 +74,6 @@ def initialize_extensions(app: DifyApp):
ext_warnings, ext_warnings,
ext_import_modules, ext_import_modules,
ext_orjson, ext_orjson,
ext_forward_refs,
ext_set_secretkey, ext_set_secretkey,
ext_compress, ext_compress,
ext_code_based_extension, ext_code_based_extension,

View File

@ -73,14 +73,14 @@ class AppExecutionConfig(BaseSettings):
description="Maximum allowed execution time for the application in seconds", description="Maximum allowed execution time for the application in seconds",
default=1200, default=1200,
) )
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Default number of concurrent active requests per app (0 for unlimited)",
default=0,
)
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field( APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
description="Maximum number of concurrent active requests per app (0 for unlimited)", description="Maximum number of concurrent active requests per app (0 for unlimited)",
default=0, default=0,
) )
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
description="Maximum number of requests per app per day",
default=5000,
)
class CodeExecutionSandboxConfig(BaseSettings): class CodeExecutionSandboxConfig(BaseSettings):
@ -1086,7 +1086,7 @@ class CeleryScheduleTasksConfig(BaseSettings):
) )
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field( TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
description="Proactive credential refresh threshold in seconds", description="Proactive credential refresh threshold in seconds",
default=60 * 60, default=180,
) )
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field( TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
description="Proactive subscription refresh threshold in seconds", description="Proactive subscription refresh threshold in seconds",

View File

@ -105,12 +105,6 @@ class KeywordStoreConfig(BaseSettings):
class DatabaseConfig(BaseSettings): class DatabaseConfig(BaseSettings):
# Database type selector
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
description="Database type to use. OceanBase is MySQL-compatible.",
default="postgresql",
)
DB_HOST: str = Field( DB_HOST: str = Field(
description="Hostname or IP address of the database server.", description="Hostname or IP address of the database server.",
default="localhost", default="localhost",
@ -146,10 +140,10 @@ class DatabaseConfig(BaseSettings):
default="", default="",
) )
@computed_field # type: ignore[prop-decorator] SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
@property description="Database URI scheme for SQLAlchemy connection.",
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str: default="postgresql",
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql" )
@computed_field # type: ignore[prop-decorator] @computed_field # type: ignore[prop-decorator]
@property @property
@ -210,15 +204,15 @@ class DatabaseConfig(BaseSettings):
# Parse DB_EXTRAS for 'options' # Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "") options = db_extras_dict.get("options", "")
connect_args = {} # Always include timezone
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property timezone_opt = "-c timezone=UTC"
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"): if options:
timezone_opt = "-c timezone=UTC" # Merge user options and timezone
if options: merged_options = f"{options} {timezone_opt}"
merged_options = f"{options} {timezone_opt}" else:
else: merged_options = timezone_opt
merged_options = timezone_opt
connect_args = {"options": merged_options} connect_args = {"options": merged_options}
return { return {
"pool_size": self.SQLALCHEMY_POOL_SIZE, "pool_size": self.SQLALCHEMY_POOL_SIZE,

View File

@ -31,8 +31,3 @@ class WeaviateConfig(BaseSettings):
description="Number of objects to be processed in a single batch operation (default is 100)", description="Number of objects to be processed in a single batch operation (default is 100)",
default=100, default=100,
) )
WEAVIATE_TOKENIZATION: str | None = Field(
description="Tokenization for Weaviate (default is word)",
default="word",
)

View File

@ -12,7 +12,7 @@ P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db from extensions.ext_database import db
from libs.token import extract_access_token from libs.token import extract_access_token
@ -38,10 +38,10 @@ def admin_required(view: Callable[P, R]):
@console_ns.route("/admin/insert-explore-apps") @console_ns.route("/admin/insert-explore-apps")
class InsertExploreAppListApi(Resource): class InsertExploreAppListApi(Resource):
@console_ns.doc("insert_explore_app") @api.doc("insert_explore_app")
@console_ns.doc(description="Insert or update an app in the explore list") @api.doc(description="Insert or update an app in the explore list")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"InsertExploreAppRequest", "InsertExploreAppRequest",
{ {
"app_id": fields.String(required=True, description="Application ID"), "app_id": fields.String(required=True, description="Application ID"),
@ -55,9 +55,9 @@ class InsertExploreAppListApi(Resource):
}, },
) )
) )
@console_ns.response(200, "App updated successfully") @api.response(200, "App updated successfully")
@console_ns.response(201, "App inserted successfully") @api.response(201, "App inserted successfully")
@console_ns.response(404, "App not found") @api.response(404, "App not found")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def post(self): def post(self):
@ -131,10 +131,10 @@ class InsertExploreAppListApi(Resource):
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>") @console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
class InsertExploreAppApi(Resource): class InsertExploreAppApi(Resource):
@console_ns.doc("delete_explore_app") @api.doc("delete_explore_app")
@console_ns.doc(description="Remove an app from the explore list") @api.doc(description="Remove an app from the explore list")
@console_ns.doc(params={"app_id": "Application ID to remove"}) @api.doc(params={"app_id": "Application ID to remove"})
@console_ns.response(204, "App removed successfully") @api.response(204, "App removed successfully")
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):

View File

@ -11,7 +11,7 @@ from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset from models.dataset import Dataset
from models.model import ApiToken, App from models.model import ApiToken, App
from . import console_ns from . import api, console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = { api_key_fields = {
@ -24,12 +24,6 @@ api_key_fields = {
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
def _get_resource(resource_id, tenant_id, resource_model): def _get_resource(resource_id, tenant_id, resource_model):
if resource_model == App: if resource_model == App:
@ -58,7 +52,7 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None token_prefix: str | None = None
max_keys = 10 max_keys = 10
@marshal_with(api_key_list_model) @marshal_with(api_key_list)
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
@ -72,7 +66,7 @@ class BaseApiKeyListResource(Resource):
).all() ).all()
return {"items": keys} return {"items": keys}
@marshal_with(api_key_item_model) @marshal_with(api_key_fields)
@edit_permission_required @edit_permission_required
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
@ -110,11 +104,14 @@ class BaseApiKeyResource(Resource):
resource_model: type | None = None resource_model: type | None = None
resource_id_field: str | None = None resource_id_field: str | None = None
def delete(self, resource_id: str, api_key_id: str): def delete(self, resource_id, api_key_id):
assert self.resource_id_field is not None, "resource_id_field must be set" assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
api_key_id = str(api_key_id)
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model) _get_resource(resource_id, current_tenant_id, self.resource_model)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
@ -139,20 +136,20 @@ class BaseApiKeyResource(Resource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys") @console_ns.route("/apps/<uuid:resource_id>/api-keys")
class AppApiKeyListResource(BaseApiKeyListResource): class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys") @api.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app") @api.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"}) @api.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "Success", api_key_list_model) @api.response(200, "Success", api_key_list)
def get(self, resource_id): # type: ignore def get(self, resource_id):
"""Get all API keys for an app""" """Get all API keys for an app"""
return super().get(resource_id) return super().get(resource_id)
@console_ns.doc("create_app_api_key") @api.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app") @api.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"}) @api.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model) @api.response(201, "API key created successfully", api_key_fields)
@console_ns.response(400, "Maximum keys exceeded") @api.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore def post(self, resource_id):
"""Create a new API key for an app""" """Create a new API key for an app"""
return super().post(resource_id) return super().post(resource_id)
@ -164,10 +161,10 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>") @console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class AppApiKeyResource(BaseApiKeyResource): class AppApiKeyResource(BaseApiKeyResource):
@console_ns.doc("delete_app_api_key") @api.doc("delete_app_api_key")
@console_ns.doc(description="Delete an API key for an app") @api.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"}) @api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully") @api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
"""Delete an API key for an app""" """Delete an API key for an app"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)
@ -179,20 +176,20 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys") @console_ns.route("/datasets/<uuid:resource_id>/api-keys")
class DatasetApiKeyListResource(BaseApiKeyListResource): class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys") @api.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset") @api.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"}) @api.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "Success", api_key_list_model) @api.response(200, "Success", api_key_list)
def get(self, resource_id): # type: ignore def get(self, resource_id):
"""Get all API keys for a dataset""" """Get all API keys for a dataset"""
return super().get(resource_id) return super().get(resource_id)
@console_ns.doc("create_dataset_api_key") @api.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset") @api.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"}) @api.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", api_key_item_model) @api.response(201, "API key created successfully", api_key_fields)
@console_ns.response(400, "Maximum keys exceeded") @api.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore def post(self, resource_id):
"""Create a new API key for a dataset""" """Create a new API key for a dataset"""
return super().post(resource_id) return super().post(resource_id)
@ -204,10 +201,10 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>") @console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
class DatasetApiKeyResource(BaseApiKeyResource): class DatasetApiKeyResource(BaseApiKeyResource):
@console_ns.doc("delete_dataset_api_key") @api.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete an API key for a dataset") @api.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"}) @api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully") @api.response(204, "API key deleted successfully")
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset""" """Delete an API key for a dataset"""
return super().delete(resource_id, api_key_id) return super().delete(resource_id, api_key_id)

View File

@ -1,39 +1,32 @@
from flask import request from flask_restx import Resource, fields, reqparse
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService from services.advanced_prompt_template_service import AdvancedPromptTemplateService
parser = (
class AdvancedPromptTemplateQuery(BaseModel): reqparse.RequestParser()
app_mode: str = Field(..., description="Application mode") .add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
model_mode: str = Field(..., description="Model mode") .add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
has_context: str = Field(default="true", description="Whether has context") .add_argument("has_context", type=str, required=False, default="true", location="args", help="Whether has context")
model_name: str = Field(..., description="Model name") .add_argument("model_name", type=str, required=True, location="args", help="Model name")
console_ns.schema_model(
AdvancedPromptTemplateQuery.__name__,
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
) )
@console_ns.route("/app/prompt-templates") @console_ns.route("/app/prompt-templates")
class AdvancedPromptTemplateList(Resource): class AdvancedPromptTemplateList(Resource):
@console_ns.doc("get_advanced_prompt_templates") @api.doc("get_advanced_prompt_templates")
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration") @api.doc(description="Get advanced prompt templates based on app mode and model configuration")
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data")) 200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
) )
@console_ns.response(400, "Invalid request parameters") @api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args.model_dump()) return AdvancedPromptTemplateService.get_prompt(args)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value from libs.helper import uuid_value
@ -17,14 +17,12 @@ parser = (
@console_ns.route("/apps/<uuid:app_id>/agent/logs") @console_ns.route("/apps/<uuid:app_id>/agent/logs")
class AgentLogApi(Resource): class AgentLogApi(Resource):
@console_ns.doc("get_agent_logs") @api.doc("get_agent_logs")
@console_ns.doc(description="Get agent execution logs for an application") @api.doc(description="Get agent execution logs for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(parser) @api.expect(parser)
@console_ns.response( @api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")) @api.response(400, "Invalid request parameters")
)
@console_ns.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, fields, marshal, marshal_with, reqparse from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from controllers.common.errors import NoFileUploadedError, TooManyFilesError from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
@ -15,7 +15,6 @@ from extensions.ext_redis import redis_client
from fields.annotation_fields import ( from fields.annotation_fields import (
annotation_fields, annotation_fields,
annotation_hit_history_fields, annotation_hit_history_fields,
build_annotation_model,
) )
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.login import login_required from libs.login import login_required
@ -24,11 +23,11 @@ from services.annotation_service import AppAnnotationService
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
class AnnotationReplyActionApi(Resource): class AnnotationReplyActionApi(Resource):
@console_ns.doc("annotation_reply_action") @api.doc("annotation_reply_action")
@console_ns.doc(description="Enable or disable annotation reply for an app") @api.doc(description="Enable or disable annotation reply for an app")
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"}) @api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"AnnotationReplyActionRequest", "AnnotationReplyActionRequest",
{ {
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"), "score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
@ -37,8 +36,8 @@ class AnnotationReplyActionApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Action completed successfully") @api.response(200, "Action completed successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -62,11 +61,11 @@ class AnnotationReplyActionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-setting") @console_ns.route("/apps/<uuid:app_id>/annotation-setting")
class AppAnnotationSettingDetailApi(Resource): class AppAnnotationSettingDetailApi(Resource):
@console_ns.doc("get_annotation_setting") @api.doc("get_annotation_setting")
@console_ns.doc(description="Get annotation settings for an app") @api.doc(description="Get annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Annotation settings retrieved successfully") @api.response(200, "Annotation settings retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -79,11 +78,11 @@ class AppAnnotationSettingDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>") @console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
class AppAnnotationSettingUpdateApi(Resource): class AppAnnotationSettingUpdateApi(Resource):
@console_ns.doc("update_annotation_setting") @api.doc("update_annotation_setting")
@console_ns.doc(description="Update annotation settings for an app") @api.doc(description="Update annotation settings for an app")
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"}) @api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"AnnotationSettingUpdateRequest", "AnnotationSettingUpdateRequest",
{ {
"score_threshold": fields.Float(required=True, description="Score threshold"), "score_threshold": fields.Float(required=True, description="Score threshold"),
@ -92,8 +91,8 @@ class AppAnnotationSettingUpdateApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Settings updated successfully") @api.response(200, "Settings updated successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -111,11 +110,11 @@ class AppAnnotationSettingUpdateApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>") @console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
class AnnotationReplyActionStatusApi(Resource): class AnnotationReplyActionStatusApi(Resource):
@console_ns.doc("get_annotation_reply_action_status") @api.doc("get_annotation_reply_action_status")
@console_ns.doc(description="Get status of annotation reply action job") @api.doc(description="Get status of annotation reply action job")
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"}) @api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
@console_ns.response(200, "Job status retrieved successfully") @api.response(200, "Job status retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -139,17 +138,17 @@ class AnnotationReplyActionStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations") @console_ns.route("/apps/<uuid:app_id>/annotations")
class AnnotationApi(Resource): class AnnotationApi(Resource):
@console_ns.doc("list_annotations") @api.doc("list_annotations")
@console_ns.doc(description="Get annotations for an app with pagination") @api.doc(description="Get annotations for an app with pagination")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.parser() api.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number") .add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size") .add_argument("limit", type=int, location="args", default=20, help="Page size")
.add_argument("keyword", type=str, location="args", default="", help="Search keyword") .add_argument("keyword", type=str, location="args", default="", help="Search keyword")
) )
@console_ns.response(200, "Annotations retrieved successfully") @api.response(200, "Annotations retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -170,11 +169,11 @@ class AnnotationApi(Resource):
} }
return response, 200 return response, 200
@console_ns.doc("create_annotation") @api.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app") @api.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"CreateAnnotationRequest", "CreateAnnotationRequest",
{ {
"message_id": fields.String(description="Message ID (optional)"), "message_id": fields.String(description="Message ID (optional)"),
@ -185,8 +184,8 @@ class AnnotationApi(Resource):
}, },
) )
) )
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns)) @api.response(201, "Annotation created successfully", annotation_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -236,15 +235,11 @@ class AnnotationApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/export") @console_ns.route("/apps/<uuid:app_id>/annotations/export")
class AnnotationExportApi(Resource): class AnnotationExportApi(Resource):
@console_ns.doc("export_annotations") @api.doc("export_annotations")
@console_ns.doc(description="Export all annotations for an app") @api.doc(description="Export all annotations for an app")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response( @api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
200, @api.response(403, "Insufficient permissions")
"Annotations exported successfully",
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -265,13 +260,13 @@ parser = (
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
class AnnotationUpdateDeleteApi(Resource): class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation") @api.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation") @api.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns)) @api.response(200, "Annotation updated successfully", annotation_fields)
@console_ns.response(204, "Annotation deleted successfully") @api.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.expect(parser) @api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -298,12 +293,12 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import") @console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
class AnnotationBatchImportApi(Resource): class AnnotationBatchImportApi(Resource):
@console_ns.doc("batch_import_annotations") @api.doc("batch_import_annotations")
@console_ns.doc(description="Batch import annotations from CSV file") @api.doc(description="Batch import annotations from CSV file")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Batch import started successfully") @api.response(200, "Batch import started successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(400, "No file uploaded or too many files") @api.response(400, "No file uploaded or too many files")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -328,11 +323,11 @@ class AnnotationBatchImportApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>") @console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
class AnnotationBatchImportStatusApi(Resource): class AnnotationBatchImportStatusApi(Resource):
@console_ns.doc("get_batch_import_status") @api.doc("get_batch_import_status")
@console_ns.doc(description="Get status of batch import job") @api.doc(description="Get status of batch import job")
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"}) @api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
@console_ns.response(200, "Job status retrieved successfully") @api.response(200, "Job status retrieved successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -355,27 +350,18 @@ class AnnotationBatchImportStatusApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories") @console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
class AnnotationHitHistoryListApi(Resource): class AnnotationHitHistoryListApi(Resource):
@console_ns.doc("list_annotation_hit_histories") @api.doc("list_annotation_hit_histories")
@console_ns.doc(description="Get hit histories for an annotation") @api.doc(description="Get hit histories for an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"}) @api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.expect( @api.expect(
console_ns.parser() api.parser()
.add_argument("page", type=int, location="args", default=1, help="Page number") .add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size") .add_argument("limit", type=int, location="args", default=20, help="Page size")
) )
@console_ns.response( @api.response(
200, 200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
"Hit histories retrieved successfully",
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
) )
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,35 +1,23 @@
import uuid import uuid
from typing import Literal
from flask import request from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
edit_permission_required, edit_permission_required,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
deleted_tool_fields,
model_config_fields,
model_config_partial_fields,
site_fields,
tag_fields,
)
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
from libs.helper import AppIconUrlField, TimestampField
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
from models import App, Workflow from models import App, Workflow
@ -39,243 +27,29 @@ from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
default="all", description="App mode filter"
)
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
if not value:
return None
if isinstance(value, str):
items = [item.strip() for item in value.split(",") if item.strip()]
elif isinstance(value, list):
items = [str(item).strip() for item in value if item and str(item).strip()]
else:
raise TypeError("Unsupported tag_ids type.")
if not items:
return None
try:
return [str(uuid.UUID(item)) for item in items]
except ValueError as exc:
raise ValueError("Invalid UUID format in tag_ids.") from exc
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
description: str | None = Field(default=None, description="Description for the copied app")
icon_type: str | None = Field(default=None, description="Icon type")
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("description")
@classmethod
def validate_description(cls, value: str | None) -> str | None:
if value is None:
return value
return validate_description_length(value)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_background: str | None = Field(default=None, description="Icon background color")
class AppSiteStatusPayload(BaseModel):
enable_site: bool = Field(..., description="Enable or disable site")
class AppApiStatusPayload(BaseModel):
enable_api: bool = Field(..., description="Enable or disable API")
class AppTracePayload(BaseModel):
enabled: bool = Field(..., description="Enable or disable tracing")
tracing_provider: str = Field(..., description="Tracing provider")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AppListQuery)
reg(CreateAppPayload)
reg(UpdateAppPayload)
reg(CopyAppPayload)
reg(AppExportQuery)
reg(AppNamePayload)
reg(AppIconPayload)
reg(AppSiteStatusPayload)
reg(AppApiStatusPayload)
reg(AppTracePayload)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base models first
tag_model = console_ns.model("Tag", tag_fields)
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
model_config_model = console_ns.model("ModelConfig", model_config_fields)
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
site_model = console_ns.model("Site", site_fields)
app_partial_model = console_ns.model(
"AppPartial",
{
"id": fields.String,
"name": fields.String,
"max_active_requests": fields.Raw(),
"description": fields.String(attribute="desc_or_prompt"),
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"tags": fields.List(fields.Nested(tag_model)),
"access_mode": fields.String,
"create_user_name": fields.String,
"author_name": fields.String,
"has_draft_trigger": fields.Boolean,
},
)
app_detail_model = console_ns.model(
"AppDetail",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon": fields.String,
"icon_background": fields.String,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"tracing": fields.Raw,
"use_icon_as_answer_icon": fields.Boolean,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
},
)
app_detail_with_site_model = console_ns.model(
"AppDetailWithSite",
{
"id": fields.String,
"name": fields.String,
"description": fields.String,
"mode": fields.String(attribute="mode_compatible_with_agent"),
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"enable_site": fields.Boolean,
"enable_api": fields.Boolean,
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
"max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
"updated_at": TimestampField,
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_model)),
"site": fields.Nested(site_model),
},
)
app_pagination_model = console_ns.model(
"AppPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
},
)
@console_ns.route("/apps") @console_ns.route("/apps")
class AppListApi(Resource): class AppListApi(Resource):
@console_ns.doc("list_apps") @api.doc("list_apps")
@console_ns.doc(description="Get list of applications with pagination and filtering") @api.doc(description="Get list of applications with pagination and filtering")
@console_ns.expect(console_ns.models[AppListQuery.__name__]) @api.expect(
@console_ns.response(200, "Success", app_pagination_model) api.parser()
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
.add_argument(
"mode",
type=str,
location="args",
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
default="all",
help="App mode filter",
)
.add_argument("name", type=str, location="args", help="Filter by app name")
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
)
@api.response(200, "Success", app_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -284,12 +58,42 @@ class AppListApi(Resource):
"""Get app list""" """Get app list"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore def uuid_list(value):
args_dict = args.model_dump() try:
return [str(uuid.UUID(v)) for v in value.split(",")]
except ValueError:
abort(400, message="Invalid UUID format in tag_ids.")
parser = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"mode",
type=str,
choices=[
"completion",
"chat",
"advanced-chat",
"workflow",
"agent-chat",
"channel",
"all",
],
default="all",
location="args",
required=False,
)
.add_argument("name", type=str, location="args", required=False)
.add_argument("tag_ids", type=uuid_list, location="args", required=False)
.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
)
args = parser.parse_args()
# get app list # get app list
app_service = AppService() app_service = AppService()
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict) app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args)
if not app_pagination: if not app_pagination:
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
@ -332,43 +136,67 @@ class AppListApi(Resource):
for app in app_pagination.items: for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
return marshal(app_pagination, app_pagination_model), 200 return marshal(app_pagination, app_pagination_fields), 200
@console_ns.doc("create_app") @api.doc("create_app")
@console_ns.doc(description="Create a new application") @api.doc(description="Create a new application")
@console_ns.expect(console_ns.models[CreateAppPayload.__name__]) @api.expect(
@console_ns.response(201, "App created successfully", app_detail_model) api.model(
@console_ns.response(403, "Insufficient permissions") "CreateAppRequest",
@console_ns.response(400, "Invalid request parameters") {
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App created successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_detail_model) @marshal_with(app_detail_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required @edit_permission_required
def post(self): def post(self):
"""Create app""" """Create app"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
args = CreateAppPayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
if "mode" not in args or args["mode"] is None:
raise BadRequest("mode is required")
app_service = AppService() app_service = AppService()
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user) app = app_service.create_app(current_tenant_id, args, current_user)
return app, 201 return app, 201
@console_ns.route("/apps/<uuid:app_id>") @console_ns.route("/apps/<uuid:app_id>")
class AppApi(Resource): class AppApi(Resource):
@console_ns.doc("get_app_detail") @api.doc("get_app_detail")
@console_ns.doc(description="Get application details") @api.doc(description="Get application details")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Success", app_detail_with_site_model) @api.response(200, "Success", app_detail_fields_with_site)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
@get_app_model @get_app_model
@marshal_with(app_detail_with_site_model) @marshal_with(app_detail_fields_with_site)
def get(self, app_model): def get(self, app_model):
"""Get app detail""" """Get app detail"""
app_service = AppService() app_service = AppService()
@ -381,43 +209,68 @@ class AppApi(Resource):
return app_model return app_model
@console_ns.doc("update_app") @api.doc("update_app")
@console_ns.doc(description="Update application details") @api.doc(description="Update application details")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__]) @api.expect(
@console_ns.response(200, "App updated successfully", app_detail_with_site_model) api.model(
@console_ns.response(403, "Insufficient permissions") "UpdateAppRequest",
@console_ns.response(400, "Invalid request parameters") {
"name": fields.String(required=True, description="App name"),
"description": fields.String(description="App description (max 400 chars)"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
"max_active_requests": fields.Integer(description="Maximum active requests"),
},
)
)
@api.response(200, "App updated successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required @edit_permission_required
@marshal_with(app_detail_with_site_model) @marshal_with(app_detail_fields_with_site)
def put(self, app_model): def put(self, app_model):
"""Update app""" """Update app"""
args = UpdateAppPayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=True, nullable=False, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
.add_argument("use_icon_as_answer_icon", type=bool, location="json")
.add_argument("max_active_requests", type=int, location="json")
)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
# Construct ArgsDict from parsed arguments
from services.app_service import AppService as AppServiceType
args_dict: AppService.ArgsDict = { args_dict: AppServiceType.ArgsDict = {
"name": args.name, "name": args["name"],
"description": args.description or "", "description": args.get("description", ""),
"icon_type": args.icon_type or "", "icon_type": args.get("icon_type", ""),
"icon": args.icon or "", "icon": args.get("icon", ""),
"icon_background": args.icon_background or "", "icon_background": args.get("icon_background", ""),
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, "use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
"max_active_requests": args.max_active_requests or 0, "max_active_requests": args.get("max_active_requests", 0),
} }
app_model = app_service.update_app(app_model, args_dict) app_model = app_service.update_app(app_model, args_dict)
return app_model return app_model
@console_ns.doc("delete_app") @api.doc("delete_app")
@console_ns.doc(description="Delete application") @api.doc(description="Delete application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(204, "App deleted successfully") @api.response(204, "App deleted successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -433,24 +286,43 @@ class AppApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/copy") @console_ns.route("/apps/<uuid:app_id>/copy")
class AppCopyApi(Resource): class AppCopyApi(Resource):
@console_ns.doc("copy_app") @api.doc("copy_app")
@console_ns.doc(description="Create a copy of an existing application") @api.doc(description="Create a copy of an existing application")
@console_ns.doc(params={"app_id": "Application ID to copy"}) @api.doc(params={"app_id": "Application ID to copy"})
@console_ns.expect(console_ns.models[CopyAppPayload.__name__]) @api.expect(
@console_ns.response(201, "App copied successfully", app_detail_with_site_model) api.model(
@console_ns.response(403, "Insufficient permissions") "CopyAppRequest",
{
"name": fields.String(description="Name for the copied app"),
"description": fields.String(description="Description for the copied app"),
"icon_type": fields.String(description="Icon type"),
"icon": fields.String(description="Icon"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(201, "App copied successfully", app_detail_fields_with_site)
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@edit_permission_required @edit_permission_required
@marshal_with(app_detail_with_site_model) @marshal_with(app_detail_fields_with_site)
def post(self, app_model): def post(self, app_model):
"""Copy app""" """Copy app"""
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = CopyAppPayload.model_validate(console_ns.payload or {}) parser = (
reqparse.RequestParser()
.add_argument("name", type=str, location="json")
.add_argument("description", type=validate_description_length, location="json")
.add_argument("icon_type", type=str, location="json")
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = AppDslService(session) import_service = AppDslService(session)
@ -459,11 +331,11 @@ class AppCopyApi(Resource):
account=current_user, account=current_user,
import_mode=ImportMode.YAML_CONTENT, import_mode=ImportMode.YAML_CONTENT,
yaml_content=yaml_content, yaml_content=yaml_content,
name=args.name, name=args.get("name"),
description=args.description, description=args.get("description"),
icon_type=args.icon_type, icon_type=args.get("icon_type"),
icon=args.icon, icon=args.get("icon"),
icon_background=args.icon_background, icon_background=args.get("icon_background"),
) )
session.commit() session.commit()
@ -475,16 +347,20 @@ class AppCopyApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/export") @console_ns.route("/apps/<uuid:app_id>/export")
class AppExportApi(Resource): class AppExportApi(Resource):
@console_ns.doc("export_app") @api.doc("export_app")
@console_ns.doc(description="Export application configuration as DSL") @api.doc(description="Export application configuration as DSL")
@console_ns.doc(params={"app_id": "Application ID to export"}) @api.doc(params={"app_id": "Application ID to export"})
@console_ns.expect(console_ns.models[AppExportQuery.__name__]) @api.expect(
@console_ns.response( api.parser()
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
)
@api.response(
200, 200,
"App exported successfully", "App exported successfully",
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}), api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
) )
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -492,114 +368,149 @@ class AppExportApi(Resource):
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
"""Export app""" """Export app"""
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore # Add include_secret params
parser = (
reqparse.RequestParser()
.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
.add_argument("workflow_id", type=str, location="args")
)
args = parser.parse_args()
return { return {
"data": AppDslService.export_dsl( "data": AppDslService.export_dsl(
app_model=app_model, app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
include_secret=args.include_secret,
workflow_id=args.workflow_id,
) )
} }
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json", help="Name to check")
@console_ns.route("/apps/<uuid:app_id>/name") @console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource): class AppNameApi(Resource):
@console_ns.doc("check_app_name") @api.doc("check_app_name")
@console_ns.doc(description="Check if app name is available") @api.doc(description="Check if app name is available")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppNamePayload.__name__]) @api.expect(parser)
@console_ns.response(200, "Name availability checked") @api.response(200, "Name availability checked")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_model) @marshal_with(app_detail_fields)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
args = AppNamePayload.model_validate(console_ns.payload) args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_name(app_model, args.name) app_model = app_service.update_app_name(app_model, args["name"])
return app_model return app_model
@console_ns.route("/apps/<uuid:app_id>/icon") @console_ns.route("/apps/<uuid:app_id>/icon")
class AppIconApi(Resource): class AppIconApi(Resource):
@console_ns.doc("update_app_icon") @api.doc("update_app_icon")
@console_ns.doc(description="Update application icon") @api.doc(description="Update application icon")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppIconPayload.__name__]) @api.expect(
@console_ns.response(200, "Icon updated successfully") api.model(
@console_ns.response(403, "Insufficient permissions") "AppIconRequest",
{
"icon": fields.String(required=True, description="Icon data"),
"icon_type": fields.String(description="Icon type"),
"icon_background": fields.String(description="Icon background color"),
},
)
)
@api.response(200, "Icon updated successfully")
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_model) @marshal_with(app_detail_fields)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
args = AppIconPayload.model_validate(console_ns.payload or {}) parser = (
reqparse.RequestParser()
.add_argument("icon", type=str, location="json")
.add_argument("icon_background", type=str, location="json")
)
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "") app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
return app_model return app_model
@console_ns.route("/apps/<uuid:app_id>/site-enable") @console_ns.route("/apps/<uuid:app_id>/site-enable")
class AppSiteStatus(Resource): class AppSiteStatus(Resource):
@console_ns.doc("update_app_site_status") @api.doc("update_app_site_status")
@console_ns.doc(description="Enable or disable app site") @api.doc(description="Enable or disable app site")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__]) @api.expect(
@console_ns.response(200, "Site status updated successfully", app_detail_model) api.model(
@console_ns.response(403, "Insufficient permissions") "AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
)
)
@api.response(200, "Site status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_model) @marshal_with(app_detail_fields)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
args = AppSiteStatusPayload.model_validate(console_ns.payload) parser = reqparse.RequestParser().add_argument("enable_site", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_site_status(app_model, args.enable_site) app_model = app_service.update_app_site_status(app_model, args["enable_site"])
return app_model return app_model
@console_ns.route("/apps/<uuid:app_id>/api-enable") @console_ns.route("/apps/<uuid:app_id>/api-enable")
class AppApiStatus(Resource): class AppApiStatus(Resource):
@console_ns.doc("update_app_api_status") @api.doc("update_app_api_status")
@console_ns.doc(description="Enable or disable app API") @api.doc(description="Enable or disable app API")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__]) @api.expect(
@console_ns.response(200, "API status updated successfully", app_detail_model) api.model(
@console_ns.response(403, "Insufficient permissions") "AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
)
)
@api.response(200, "API status updated successfully", app_detail_fields)
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_detail_model) @marshal_with(app_detail_fields)
def post(self, app_model): def post(self, app_model):
args = AppApiStatusPayload.model_validate(console_ns.payload) # The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser().add_argument("enable_api", type=bool, required=True, location="json")
args = parser.parse_args()
app_service = AppService() app_service = AppService()
app_model = app_service.update_app_api_status(app_model, args.enable_api) app_model = app_service.update_app_api_status(app_model, args["enable_api"])
return app_model return app_model
@console_ns.route("/apps/<uuid:app_id>/trace") @console_ns.route("/apps/<uuid:app_id>/trace")
class AppTraceApi(Resource): class AppTraceApi(Resource):
@console_ns.doc("get_app_trace") @api.doc("get_app_trace")
@console_ns.doc(description="Get app tracing configuration") @api.doc(description="Get app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Trace configuration retrieved successfully") @api.response(200, "Trace configuration retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -609,24 +520,37 @@ class AppTraceApi(Resource):
return app_trace_config return app_trace_config
@console_ns.doc("update_app_trace") @api.doc("update_app_trace")
@console_ns.doc(description="Update app tracing configuration") @api.doc(description="Update app tracing configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppTracePayload.__name__]) @api.expect(
@console_ns.response(200, "Trace configuration updated successfully") api.model(
@console_ns.response(403, "Insufficient permissions") "AppTraceRequest",
{
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
"tracing_provider": fields.String(required=True, description="Tracing provider"),
},
)
)
@api.response(200, "Trace configuration updated successfully")
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required @edit_permission_required
def post(self, app_id): def post(self, app_id):
# add app trace # add app trace
args = AppTracePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("enabled", type=bool, required=True, location="json")
.add_argument("tracing_provider", type=str, required=True, location="json")
)
args = parser.parse_args()
OpsTraceManager.update_app_tracing_config( OpsTraceManager.update_app_tracing_config(
app_id=app_id, app_id=app_id,
enabled=args.enabled, enabled=args["enabled"],
tracing_provider=args.tracing_provider, tracing_provider=args["tracing_provider"],
) )
return {"result": "success"} return {"result": "success"}

View File

@ -1,6 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -9,11 +10,7 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import App from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
@ -22,19 +19,6 @@ from services.feature_service import FeatureService
from .. import console_ns from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("mode", type=str, required=True, location="json") .add_argument("mode", type=str, required=True, location="json")
@ -51,11 +35,11 @@ parser = (
@console_ns.route("/apps/imports") @console_ns.route("/apps/imports")
class AppImportApi(Resource): class AppImportApi(Resource):
@console_ns.expect(parser) @api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_model) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps") @cloud_edition_billing_resource_check("apps")
@edit_permission_required @edit_permission_required
def post(self): def post(self):
@ -98,7 +82,7 @@ class AppImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_model) @marshal_with(app_import_fields)
@edit_permission_required @edit_permission_required
def post(self, import_id): def post(self, import_id):
# Check user role first # Check user role first
@ -124,7 +108,7 @@ class AppImportCheckDependenciesApi(Resource):
@login_required @login_required
@get_app_model @get_app_model
@account_initialization_required @account_initialization_required
@marshal_with(app_import_check_dependencies_model) @marshal_with(app_import_check_dependencies_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
with Session(db.engine) as session: with Session(db.engine) as session:

View File

@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@ -36,16 +36,16 @@ logger = logging.getLogger(__name__)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text") @console_ns.route("/apps/<uuid:app_id>/audio-to-text")
class ChatMessageAudioApi(Resource): class ChatMessageAudioApi(Resource):
@console_ns.doc("chat_message_audio_transcript") @api.doc("chat_message_audio_transcript")
@console_ns.doc(description="Transcript audio to text for chat messages") @api.doc(description="Transcript audio to text for chat messages")
@console_ns.doc(params={"app_id": "App ID"}) @api.doc(params={"app_id": "App ID"})
@console_ns.response( @api.response(
200, 200,
"Audio transcription successful", "Audio transcription successful",
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}), api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
) )
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type") @api.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large") @api.response(413, "Audio file too large")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -89,11 +89,11 @@ class ChatMessageAudioApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio") @console_ns.route("/apps/<uuid:app_id>/text-to-audio")
class ChatMessageTextApi(Resource): class ChatMessageTextApi(Resource):
@console_ns.doc("chat_message_text_to_speech") @api.doc("chat_message_text_to_speech")
@console_ns.doc(description="Convert text to speech for chat messages") @api.doc(description="Convert text to speech for chat messages")
@console_ns.doc(params={"app_id": "App ID"}) @api.doc(params={"app_id": "App ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"TextToSpeechRequest", "TextToSpeechRequest",
{ {
"message_id": fields.String(description="Message ID"), "message_id": fields.String(description="Message ID"),
@ -103,8 +103,8 @@ class ChatMessageTextApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Text to speech conversion successful") @api.response(200, "Text to speech conversion successful")
@console_ns.response(400, "Bad request - Invalid parameters") @api.response(400, "Bad request - Invalid parameters")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -156,16 +156,12 @@ class ChatMessageTextApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices") @console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
class TextModesApi(Resource): class TextModesApi(Resource):
@console_ns.doc("get_text_to_speech_voices") @api.doc("get_text_to_speech_voices")
@console_ns.doc(description="Get available TTS voices for a specific language") @api.doc(description="Get available TTS voices for a specific language")
@console_ns.doc(params={"app_id": "App ID"}) @api.doc(params={"app_id": "App ID"})
@console_ns.expect( @api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code") @api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
) @api.response(400, "Invalid language parameter")
@console_ns.response(
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
)
@console_ns.response(400, "Invalid language parameter")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required

View File

@ -1,13 +1,11 @@
import logging import logging
from typing import Any, Literal
from flask import request from flask import request
from flask_restx import Resource from flask_restx import Resource, fields, reqparse
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
@ -19,6 +17,7 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -33,66 +32,50 @@ from libs.login import current_user, login_required
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseMessagePayload(BaseModel):
inputs: dict[str, Any]
model_config_data: dict[str, Any] = Field(..., alias="model_config")
files: list[Any] | None = Field(default=None, description="Uploaded files")
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
retriever_from: str = Field(default="dev", description="Retriever source")
class CompletionMessagePayload(BaseMessagePayload):
query: str = Field(default="", description="Query text")
class ChatMessagePayload(BaseMessagePayload):
query: str = Field(..., description="User query")
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
console_ns.schema_model(
CompletionMessagePayload.__name__,
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# define completion message api for user # define completion message api for user
@console_ns.route("/apps/<uuid:app_id>/completion-messages") @console_ns.route("/apps/<uuid:app_id>/completion-messages")
class CompletionMessageApi(Resource): class CompletionMessageApi(Resource):
@console_ns.doc("create_completion_message") @api.doc("create_completion_message")
@console_ns.doc(description="Generate completion message for debugging") @api.doc(description="Generate completion message for debugging")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__]) @api.expect(
@console_ns.response(200, "Completion generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "CompletionMessageRequest",
@console_ns.response(404, "App not found") {
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(description="Query text", default=""),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Completion generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model): def post(self, app_model):
args_model = CompletionMessagePayload.model_validate(console_ns.payload) parser = (
args = args_model.model_dump(exclude_none=True, by_alias=True) reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking" streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
try: try:
@ -127,10 +110,10 @@ class CompletionMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
class CompletionMessageStopApi(Resource): class CompletionMessageStopApi(Resource):
@console_ns.doc("stop_completion_message") @api.doc("stop_completion_message")
@console_ns.doc(description="Stop a running completion message generation") @api.doc(description="Stop a running completion message generation")
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@console_ns.response(200, "Task stopped successfully") @api.response(200, "Task stopped successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -138,36 +121,54 @@ class CompletionMessageStopApi(Resource):
def post(self, app_model, task_id): def post(self, app_model, task_id):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.route("/apps/<uuid:app_id>/chat-messages") @console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageApi(Resource): class ChatMessageApi(Resource):
@console_ns.doc("create_chat_message") @api.doc("create_chat_message")
@console_ns.doc(description="Generate chat message for debugging") @api.doc(description="Generate chat message for debugging")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) @api.expect(
@console_ns.response(200, "Chat message generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "ChatMessageRequest",
@console_ns.response(404, "App or conversation not found") {
"inputs": fields.Raw(required=True, description="Input variables"),
"query": fields.String(required=True, description="User query"),
"files": fields.List(fields.Raw(), description="Uploaded files"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"conversation_id": fields.String(description="Conversation ID"),
"parent_message_id": fields.String(description="Parent message ID"),
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
"retriever_from": fields.String(default="dev", description="Retriever source"),
},
)
)
@api.response(200, "Chat message generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(404, "App or conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
args_model = ChatMessagePayload.model_validate(console_ns.payload) parser = (
args = args_model.model_dump(exclude_none=True, by_alias=True) reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("model_config", type=dict, required=True, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
)
args = parser.parse_args()
streaming = args_model.response_mode != "blocking" streaming = args["response_mode"] != "blocking"
args["auto_generate_name"] = False args["auto_generate_name"] = False
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
@ -208,10 +209,10 @@ class ChatMessageApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
class ChatMessageStopApi(Resource): class ChatMessageStopApi(Resource):
@console_ns.doc("stop_chat_message") @api.doc("stop_chat_message")
@console_ns.doc(description="Stop a running chat message generation") @api.doc(description="Stop a running chat message generation")
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"}) @api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
@console_ns.response(200, "Task stopped successfully") @api.response(200, "Task stopped successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -219,12 +220,6 @@ class ChatMessageStopApi(Resource):
def post(self, app_model, task_id): def post(self, app_model, task_id):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.DEBUGGER,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,353 +1,88 @@
from typing import Literal
import sqlalchemy as sa import sqlalchemy as sa
from flask import abort, request from flask import abort
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, marshal_with, reqparse
from pydantic import BaseModel, Field, field_validator from flask_restx.inputs import int_range
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_fields import MessageTextField from fields.conversation_fields import (
from fields.raws import FilesContainedField conversation_detail_fields,
conversation_message_detail_fields,
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
from libs.datetime_utils import naive_utc_now, parse_time_range from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
from models.model import AppMode from models.model import AppMode
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class BaseConversationQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword")
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
default="all", description="Annotation status filter"
)
page: int = Field(default=1, ge=1, le=99999, description="Page number")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
class CompletionConversationQuery(BaseConversationQuery):
pass
class ChatConversationQuery(BaseConversationQuery):
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort field and direction"
)
console_ns.schema_model(
CompletionConversationQuery.__name__,
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ChatConversationQuery.__name__,
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
feedback_stat_model = console_ns.model(
"FeedbackStat",
{
"like": fields.Integer,
"dislike": fields.Integer,
},
)
status_count_model = console_ns.model(
"StatusCount",
{
"success": fields.Integer,
"failed": fields.Integer,
"partial_success": fields.Integer,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
simple_model_config_model = console_ns.model(
"SimpleModelConfig",
{
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
},
)
model_config_model = console_ns.model(
"ModelConfig",
{
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",
{
"inputs": FilesContainedField,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Conversation models
conversation_fields_model = console_ns.model(
"Conversation",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"from_account_name": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotation": fields.Nested(annotation_model, allow_null=True),
"model_config": fields.Nested(simple_model_config_model),
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
},
)
conversation_pagination_model = console_ns.model(
"ConversationPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
},
)
conversation_message_detail_model = console_ns.model(
"ConversationMessageDetail",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_model),
"message": fields.Nested(message_detail_model, attribute="first_message"),
},
)
conversation_with_summary_model = console_ns.model(
"ConversationWithSummary",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"from_account_name": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_model),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
"status_count": fields.Nested(status_count_model),
},
)
conversation_with_summary_pagination_model = console_ns.model(
"ConversationWithSummaryPagination",
{
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
},
)
conversation_detail_model = console_ns.model(
"ConversationDetail",
{
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_model),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_model),
"admin_feedback_stats": fields.Nested(feedback_stat_model),
},
)
@console_ns.route("/apps/<uuid:app_id>/completion-conversations") @console_ns.route("/apps/<uuid:app_id>/completion-conversations")
class CompletionConversationApi(Resource): class CompletionConversationApi(Resource):
@console_ns.doc("list_completion_conversations") @api.doc("list_completion_conversations")
@console_ns.doc(description="Get completion conversations with pagination and filtering") @api.doc(description="Get completion conversations with pagination and filtering")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__]) @api.expect(
@console_ns.response(200, "Success", conversation_pagination_model) api.parser()
@console_ns.response(403, "Insufficient permissions") .add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
)
@api.response(200, "Success", conversation_pagination_fields)
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_pagination_model) @marshal_with(conversation_pagination_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
query = sa.select(Conversation).where( query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False) Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
) )
if args.keyword: if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).where( query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_( or_(
Message.query.ilike(f"%{args.keyword}%"), Message.query.ilike(f"%{args['keyword']}%"),
Message.answer.ilike(f"%{args.keyword}%"), Message.answer.ilike(f"%{args['keyword']}%"),
) )
) )
@ -355,7 +90,7 @@ class CompletionConversationApi(Resource):
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -367,11 +102,11 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc) query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file # FIXME, the type ignore in this file
if args.annotation_status == "annotated": if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args.annotation_status == "not_annotated": elif args["annotation_status"] == "not_annotated":
query = ( query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
@ -380,36 +115,36 @@ class CompletionConversationApi(Resource):
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations return conversations
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") @console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
class CompletionConversationDetailApi(Resource): class CompletionConversationDetailApi(Resource):
@console_ns.doc("get_completion_conversation") @api.doc("get_completion_conversation")
@console_ns.doc(description="Get completion conversation details with messages") @api.doc(description="Get completion conversation details with messages")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(200, "Success", conversation_message_detail_model) @api.response(200, "Success", conversation_message_detail_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.COMPLETION) @get_app_model(mode=AppMode.COMPLETION)
@marshal_with(conversation_message_detail_model) @marshal_with(conversation_message_detail_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@console_ns.doc("delete_completion_conversation") @api.doc("delete_completion_conversation")
@console_ns.doc(description="Delete a completion conversation") @api.doc(description="Delete a completion conversation")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(204, "Conversation deleted successfully") @api.response(204, "Conversation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -429,21 +164,69 @@ class CompletionConversationDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-conversations") @console_ns.route("/apps/<uuid:app_id>/chat-conversations")
class ChatConversationApi(Resource): class ChatConversationApi(Resource):
@console_ns.doc("list_chat_conversations") @api.doc("list_chat_conversations")
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary") @api.doc(description="Get chat conversations with pagination, filtering and summary")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__]) @api.expect(
@console_ns.response(200, "Success", conversation_with_summary_pagination_model) api.parser()
@console_ns.response(403, "Insufficient permissions") .add_argument("keyword", type=str, location="args", help="Search keyword")
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
.add_argument(
"annotation_status",
type=str,
location="args",
choices=["annotated", "not_annotated", "all"],
default="all",
help="Annotation status filter",
)
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
.add_argument("page", type=int, location="args", default=1, help="Page number")
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
.add_argument(
"sort_by",
type=str,
location="args",
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
default="-updated_at",
help="Sort field and direction",
)
)
@api.response(200, "Success", conversation_with_summary_pagination_fields)
@api.response(403, "Insufficient permissions")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_with_summary_pagination_model) @marshal_with(conversation_with_summary_pagination_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument(
"annotation_status",
type=str,
choices=["annotated", "not_annotated", "all"],
default="all",
location="args",
)
.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
.add_argument(
"sort_by",
type=str,
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
required=False,
default="-updated_at",
location="args",
)
)
args = parser.parse_args()
subquery = ( subquery = (
db.session.query( db.session.query(
@ -455,8 +238,8 @@ class ChatConversationApi(Resource):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False)) query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword: if args["keyword"]:
keyword_filter = f"%{args.keyword}%" keyword_filter = f"%{args['keyword']}%"
query = ( query = (
query.join( query.join(
Message, Message,
@ -479,12 +262,12 @@ class ChatConversationApi(Resource):
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
if start_datetime_utc: if start_datetime_utc:
match args.sort_by: match args["sort_by"]:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc) query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
@ -492,35 +275,35 @@ class ChatConversationApi(Resource):
if end_datetime_utc: if end_datetime_utc:
end_datetime_utc = end_datetime_utc.replace(second=59) end_datetime_utc = end_datetime_utc.replace(second=59)
match args.sort_by: match args["sort_by"]:
case "updated_at" | "-updated_at": case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc) query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _: case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc) query = query.where(Conversation.created_at <= end_datetime_utc)
if args.annotation_status == "annotated": if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args.annotation_status == "not_annotated": elif args["annotation_status"] == "not_annotated":
query = ( query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0) .having(func.count(MessageAnnotation.id) == 0)
) )
if args.message_count_gte and args.message_count_gte >= 1: if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = ( query = (
query.options(joinedload(Conversation.messages)) # type: ignore query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id) .join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(Message.id) >= args.message_count_gte) .having(func.count(Message.id) >= args["message_count_gte"])
) )
if app_model.mode == AppMode.ADVANCED_CHAT: if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER) query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
match args.sort_by: match args["sort_by"]:
case "created_at": case "created_at":
query = query.order_by(Conversation.created_at.asc()) query = query.order_by(Conversation.created_at.asc())
case "-created_at": case "-created_at":
@ -532,36 +315,36 @@ class ChatConversationApi(Resource):
case _: case _:
query = query.order_by(Conversation.created_at.desc()) query = query.order_by(Conversation.created_at.desc())
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False) conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
return conversations return conversations
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") @console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
class ChatConversationDetailApi(Resource): class ChatConversationDetailApi(Resource):
@console_ns.doc("get_chat_conversation") @api.doc("get_chat_conversation")
@console_ns.doc(description="Get chat conversation details") @api.doc(description="Get chat conversation details")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(200, "Success", conversation_detail_model) @api.response(200, "Success", conversation_detail_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(conversation_detail_model) @marshal_with(conversation_detail_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model, conversation_id): def get(self, app_model, conversation_id):
conversation_id = str(conversation_id) conversation_id = str(conversation_id)
return _get_conversation(app_model, conversation_id) return _get_conversation(app_model, conversation_id)
@console_ns.doc("delete_chat_conversation") @api.doc("delete_chat_conversation")
@console_ns.doc(description="Delete a chat conversation") @api.doc(description="Delete a chat conversation")
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"}) @api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
@console_ns.response(204, "Conversation deleted successfully") @api.response(204, "Conversation deleted successfully")
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Conversation not found") @api.response(404, "Conversation not found")
@setup_required @setup_required
@login_required @login_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])

View File

@ -1,68 +1,46 @@
from flask import request from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.conversation_variable_fields import ( from fields.conversation_variable_fields import paginated_conversation_variable_fields
conversation_variable_fields,
paginated_conversation_variable_fields,
)
from libs.login import login_required from libs.login import login_required
from models import ConversationVariable from models import ConversationVariable
from models.model import AppMode from models.model import AppMode
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ConversationVariablesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID to filter variables")
console_ns.schema_model(
ConversationVariablesQuery.__name__,
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
# For nested models, need to replace nested dict with registered model
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
paginated_conversation_variable_fields_copy["data"] = fields.List(
fields.Nested(conversation_variable_model), attribute="data"
)
paginated_conversation_variable_model = console_ns.model(
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
)
@console_ns.route("/apps/<uuid:app_id>/conversation-variables") @console_ns.route("/apps/<uuid:app_id>/conversation-variables")
class ConversationVariablesApi(Resource): class ConversationVariablesApi(Resource):
@console_ns.doc("get_conversation_variables") @api.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for an application") @api.doc(description="Get conversation variables for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__]) @api.expect(
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model) api.parser().add_argument(
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
)
)
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=AppMode.ADVANCED_CHAT) @get_app_model(mode=AppMode.ADVANCED_CHAT)
@marshal_with(paginated_conversation_variable_model) @marshal_with(paginated_conversation_variable_fields)
def get(self, app_model): def get(self, app_model):
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = reqparse.RequestParser().add_argument("conversation_id", type=str, location="args")
args = parser.parse_args()
stmt = ( stmt = (
select(ConversationVariable) select(ConversationVariable)
.where(ConversationVariable.app_id == app_model.id) .where(ConversationVariable.app_id == app_model.id)
.order_by(ConversationVariable.created_at) .order_by(ConversationVariable.created_at)
) )
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id) if args["conversation_id"]:
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
else:
raise ValueError("conversation_id is required")
# NOTE: This is a temporary solution to avoid performance issues. # NOTE: This is a temporary solution to avoid performance issues.
page = 1 page = 1

View File

@ -1,10 +1,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
from flask_restx import Resource from flask_restx import Resource, fields, reqparse
from pydantic import BaseModel, Field
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -23,70 +21,43 @@ from libs.login import current_account_with_tenant, login_required
from models import App from models import App
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(RuleGeneratePayload)
reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
@console_ns.route("/rule-generate") @console_ns.route("/rule-generate")
class RuleGenerateApi(Resource): class RuleGenerateApi(Resource):
@console_ns.doc("generate_rule_config") @api.doc("generate_rule_config")
@console_ns.doc(description="Generate rule configuration using LLM") @api.doc(description="Generate rule configuration using LLM")
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__]) @api.expect(
@console_ns.response(200, "Rule configuration generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "RuleGenerateRequest",
@console_ns.response(402, "Provider quota exceeded") {
"instruction": fields.String(required=True, description="Rule generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
},
)
)
@api.response(200, "Rule configuration generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
no_variable=args.no_variable, no_variable=args["no_variable"],
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -102,25 +73,44 @@ class RuleGenerateApi(Resource):
@console_ns.route("/rule-code-generate") @console_ns.route("/rule-code-generate")
class RuleCodeGenerateApi(Resource): class RuleCodeGenerateApi(Resource):
@console_ns.doc("generate_rule_code") @api.doc("generate_rule_code")
@console_ns.doc(description="Generate code rules using LLM") @api.doc(description="Generate code rules using LLM")
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__]) @api.expect(
@console_ns.response(200, "Code rules generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "RuleCodeGenerateRequest",
@console_ns.response(402, "Provider quota exceeded") {
"instruction": fields.String(required=True, description="Code generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
"code_language": fields.String(
default="javascript", description="Programming language for code generation"
),
},
)
)
@api.response(200, "Code rules generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("no_variable", type=bool, required=True, default=False, location="json")
.add_argument("code_language", type=str, required=False, default="javascript", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
code_language=args.code_language, code_language=args["code_language"],
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -136,24 +126,37 @@ class RuleCodeGenerateApi(Resource):
@console_ns.route("/rule-structured-output-generate") @console_ns.route("/rule-structured-output-generate")
class RuleStructuredOutputGenerateApi(Resource): class RuleStructuredOutputGenerateApi(Resource):
@console_ns.doc("generate_structured_output") @api.doc("generate_structured_output")
@console_ns.doc(description="Generate structured output rules using LLM") @api.doc(description="Generate structured output rules using LLM")
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__]) @api.expect(
@console_ns.response(200, "Structured output generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "StructuredOutputGenerateRequest",
@console_ns.response(402, "Provider quota exceeded") {
"instruction": fields.String(required=True, description="Structured output generation instruction"),
"model_config": fields.Raw(required=True, description="Model configuration"),
},
)
)
@api.response(200, "Structured output generated successfully")
@api.response(400, "Invalid request parameters")
@api.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
try: try:
structured_output = LLMGenerator.generate_structured_output( structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -169,79 +172,102 @@ class RuleStructuredOutputGenerateApi(Resource):
@console_ns.route("/instruction-generate") @console_ns.route("/instruction-generate")
class InstructionGenerateApi(Resource): class InstructionGenerateApi(Resource):
@console_ns.doc("generate_instruction") @api.doc("generate_instruction")
@console_ns.doc(description="Generate instruction for workflow nodes or general use") @api.doc(description="Generate instruction for workflow nodes or general use")
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__]) @api.expect(
@console_ns.response(200, "Instruction generated successfully") api.model(
@console_ns.response(400, "Invalid request parameters or flow/workflow not found") "InstructionGenerateRequest",
@console_ns.response(402, "Provider quota exceeded") {
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
"node_id": fields.String(description="Node ID for workflow context"),
"current": fields.String(description="Current instruction text"),
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
"instruction": fields.String(required=True, description="Instruction for generation"),
"model_config": fields.Raw(required=True, description="Model configuration"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Instruction generated successfully")
@api.response(400, "Invalid request parameters or flow/workflow not found")
@api.response(402, "Provider quota exceeded")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("flow_id", type=str, required=True, default="", location="json")
.add_argument("node_id", type=str, required=False, default="", location="json")
.add_argument("current", type=str, required=False, default="", location="json")
.add_argument("language", type=str, required=False, default="javascript", location="json")
.add_argument("instruction", type=str, required=True, nullable=False, location="json")
.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
.add_argument("ideal_output", type=str, required=False, default="", location="json")
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider] providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next( code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None (p for p in providers if p.is_accept_language(args["language"])), None
) )
code_template = code_provider.get_default_code() if code_provider else "" code_template = code_provider.get_default_code() if code_provider else ""
try: try:
# Generate from nothing for a workflow node # Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "": if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
app = db.session.query(App).where(App.id == args.flow_id).first() app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app: if not app:
return {"error": f"app {args.flow_id} not found"}, 400 return {"error": f"app {args['flow_id']} not found"}, 400
workflow = WorkflowService().get_draft_workflow(app_model=app) workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow: if not workflow:
return {"error": f"workflow {args.flow_id} not found"}, 400 return {"error": f"workflow {args['flow_id']} not found"}, 400
nodes: Sequence = workflow.graph_dict["nodes"] nodes: Sequence = workflow.graph_dict["nodes"]
node = [node for node in nodes if node["id"] == args.node_id] node = [node for node in nodes if node["id"] == args["node_id"]]
if len(node) == 0: if len(node) == 0:
return {"error": f"node {args.node_id} not found"}, 400 return {"error": f"node {args['node_id']} not found"}, 400
node_type = node[0]["data"]["type"] node_type = node[0]["data"]["type"]
match node_type: match node_type:
case "llm": case "llm":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "agent": case "agent":
return LLMGenerator.generate_rule_config( return LLMGenerator.generate_rule_config(
current_tenant_id, current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
no_variable=True, no_variable=True,
) )
case "code": case "code":
return LLMGenerator.generate_code( return LLMGenerator.generate_code(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
code_language=args.language, code_language=args["language"],
) )
case _: case _:
return {"error": f"invalid node type: {node_type}"} return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
return LLMGenerator.instruction_modify_legacy( return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
flow_id=args.flow_id, flow_id=args["flow_id"],
current=args.current, current=args["current"],
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
ideal_output=args.ideal_output, ideal_output=args["ideal_output"],
) )
if args.node_id != "" and args.current != "": # For workflow node if args["node_id"] != "" and args["current"] != "": # For workflow node
return LLMGenerator.instruction_modify_workflow( return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
flow_id=args.flow_id, flow_id=args["flow_id"],
node_id=args.node_id, node_id=args["node_id"],
current=args.current, current=args["current"],
instruction=args.instruction, instruction=args["instruction"],
model_config=args.model_config_data, model_config=args["model_config"],
ideal_output=args.ideal_output, ideal_output=args["ideal_output"],
workflow_service=WorkflowService(), workflow_service=WorkflowService(),
) )
return {"error": "incompatible parameters"}, 400 return {"error": "incompatible parameters"}, 400
@ -257,17 +283,26 @@ class InstructionGenerateApi(Resource):
@console_ns.route("/instruction-generate/template") @console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource): class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template") @api.doc("get_instruction_template")
@console_ns.doc(description="Get instruction generation template") @api.doc(description="Get instruction generation template")
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__]) @api.expect(
@console_ns.response(200, "Template retrieved successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "InstructionTemplateRequest",
{
"instruction": fields.String(required=True, description="Template instruction"),
"ideal_output": fields.String(description="Expected ideal output"),
},
)
)
@api.response(200, "Template retrieved successfully")
@api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = InstructionTemplatePayload.model_validate(console_ns.payload) parser = reqparse.RequestParser().add_argument("type", type=str, required=True, default=False, location="json")
match args.type: args = parser.parse_args()
match args["type"]:
case "prompt": case "prompt":
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
@ -277,4 +312,4 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE} return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _: case _:
raise ValueError(f"Invalid type: {args.type}") raise ValueError(f"Invalid type: {args['type']}")

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
@ -12,9 +12,6 @@ from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMCPServer from models.model import AppMCPServer
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
class AppMCPServerStatus(StrEnum): class AppMCPServerStatus(StrEnum):
ACTIVE = "active" ACTIVE = "active"
@ -23,24 +20,24 @@ class AppMCPServerStatus(StrEnum):
@console_ns.route("/apps/<uuid:app_id>/server") @console_ns.route("/apps/<uuid:app_id>/server")
class AppMCPServerController(Resource): class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server") @api.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application") @api.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model) @api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required @setup_required
@get_app_model @get_app_model
@marshal_with(app_server_model) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server return server
@console_ns.doc("create_app_mcp_server") @api.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application") @api.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"MCPServerCreateRequest", "MCPServerCreateRequest",
{ {
"description": fields.String(description="Server description"), "description": fields.String(description="Server description"),
@ -48,13 +45,13 @@ class AppMCPServerController(Resource):
}, },
) )
) )
@console_ns.response(201, "MCP server configuration created successfully", app_server_model) @api.response(201, "MCP server configuration created successfully", app_server_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@login_required @login_required
@setup_required @setup_required
@marshal_with(app_server_model) @marshal_with(app_server_fields)
@edit_permission_required @edit_permission_required
def post(self, app_model): def post(self, app_model):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
@ -82,11 +79,11 @@ class AppMCPServerController(Resource):
db.session.commit() db.session.commit()
return server return server
@console_ns.doc("update_app_mcp_server") @api.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application") @api.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"MCPServerUpdateRequest", "MCPServerUpdateRequest",
{ {
"id": fields.String(required=True, description="Server ID"), "id": fields.String(required=True, description="Server ID"),
@ -96,14 +93,14 @@ class AppMCPServerController(Resource):
}, },
) )
) )
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model) @api.response(200, "MCP server configuration updated successfully", app_server_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found") @api.response(404, "Server not found")
@get_app_model @get_app_model
@login_required @login_required
@setup_required @setup_required
@account_initialization_required @account_initialization_required
@marshal_with(app_server_model) @marshal_with(app_server_fields)
@edit_permission_required @edit_permission_required
def put(self, app_model): def put(self, app_model):
parser = ( parser = (
@ -137,16 +134,16 @@ class AppMCPServerController(Resource):
@console_ns.route("/apps/<uuid:server_id>/server/refresh") @console_ns.route("/apps/<uuid:server_id>/server/refresh")
class AppMCPServerRefreshController(Resource): class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server") @api.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code") @api.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"}) @api.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "MCP server refreshed successfully", app_server_model) @api.response(200, "MCP server refreshed successfully", app_server_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found") @api.response(404, "Server not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_server_model) @marshal_with(app_server_fields)
@edit_permission_required @edit_permission_required
def get(self, server_id): def get(self, server_id):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()

View File

@ -1,13 +1,11 @@
import logging import logging
from typing import Literal
from flask import request from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with from flask_restx.inputs import int_range
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -25,8 +23,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from fields.raws import FilesContainedField from fields.conversation_fields import message_detail_fields
from libs.helper import TimestampField, uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
@ -35,216 +33,55 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService from services.message_service import MessageService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model(
"SimpleAccount",
{
"id": fields.String,
"name": fields.String,
"email": fields.String,
},
)
message_file_model = console_ns.model(
"MessageFile",
{
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
},
)
agent_thought_model = console_ns.model(
"AgentThought",
{
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
},
)
# Models that depend on simple_account_model
feedback_model = console_ns.model(
"Feedback",
{
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_model, allow_null=True),
},
)
annotation_model = console_ns.model(
"Annotation",
{
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
annotation_hit_history_model = console_ns.model(
"AnnotationHitHistory",
{
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
"created_at": TimestampField,
},
)
# Message detail model that depends on multiple models
message_detail_model = console_ns.model(
"MessageDetail",
{
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_model)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_model, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
},
)
# Message infinite scroll pagination model
message_infinite_scroll_pagination_model = console_ns.model(
"MessageInfiniteScrollPagination",
{
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_detail_model)),
},
)
@console_ns.route("/apps/<uuid:app_id>/chat-messages") @console_ns.route("/apps/<uuid:app_id>/chat-messages")
class ChatMessageListApi(Resource): class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages") message_infinite_scroll_pagination_fields = {
@console_ns.doc(description="Get chat messages for a conversation with pagination") "limit": fields.Integer,
@console_ns.doc(params={"app_id": "Application ID"}) "has_more": fields.Boolean,
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__]) "data": fields.List(fields.Nested(message_detail_fields)),
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model) }
@console_ns.response(404, "Conversation not found")
@api.doc("list_chat_messages")
@api.doc(description="Get chat messages for a conversation with pagination")
@api.doc(params={"app_id": "Application ID"})
@api.expect(
api.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
@api.response(404, "Conversation not found")
@login_required @login_required
@account_initialization_required @account_initialization_required
@setup_required @setup_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@marshal_with(message_infinite_scroll_pagination_model) @marshal_with(message_infinite_scroll_pagination_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model): def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first() .first()
) )
if not conversation: if not conversation:
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if args.first_id: if args["first_id"]:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args.first_id) .where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first() .first()
) )
@ -259,7 +96,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id, Message.id != first_message.id,
) )
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args.limit) .limit(args["limit"])
.all() .all()
) )
else: else:
@ -267,12 +104,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args.limit) .limit(args["limit"])
.all() .all()
) )
# Initialize has_more based on whether we have a full page # Initialize has_more based on whether we have a full page
if len(history_messages) == args.limit: if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page # Check if there are more messages before the current page
has_more = db.session.scalar( has_more = db.session.scalar(
@ -290,18 +127,26 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages)) history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more) return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks") @console_ns.route("/apps/<uuid:app_id>/feedbacks")
class MessageFeedbackApi(Resource): class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback") @api.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)") @api.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) @api.expect(
@console_ns.response(200, "Feedback updated successfully") api.model(
@console_ns.response(404, "Message not found") "MessageFeedbackRequest",
@console_ns.response(403, "Insufficient permissions") {
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@api.response(200, "Feedback updated successfully")
@api.response(404, "Message not found")
@api.response(403, "Insufficient permissions")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -309,9 +154,14 @@ class MessageFeedbackApi(Resource):
def post(self, app_model): def post(self, app_model):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = MessageFeedbackPayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
message_id = str(args.message_id) message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@ -320,21 +170,18 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback feedback = message.admin_feedback
if not args.rating and feedback: if not args["rating"] and feedback:
db.session.delete(feedback) db.session.delete(feedback)
elif args.rating and feedback: elif args["rating"] and feedback:
feedback.rating = args.rating feedback.rating = args["rating"]
elif not args.rating and not feedback: elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists") raise ValueError("rating cannot be None when feedback not exists")
else: else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback( feedback = MessageFeedback(
app_id=app_model.id, app_id=app_model.id,
conversation_id=message.conversation_id, conversation_id=message.conversation_id,
message_id=message.id, message_id=message.id,
rating=rating_value, rating=args["rating"],
from_source="admin", from_source="admin",
from_account_id=current_user.id, from_account_id=current_user.id,
) )
@ -347,13 +194,13 @@ class MessageFeedbackApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/annotations/count") @console_ns.route("/apps/<uuid:app_id>/annotations/count")
class MessageAnnotationCountApi(Resource): class MessageAnnotationCountApi(Resource):
@console_ns.doc("get_annotation_count") @api.doc("get_annotation_count")
@console_ns.doc(description="Get count of message annotations for the app") @api.doc(description="Get count of message annotations for the app")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response( @api.response(
200, 200,
"Annotation count retrieved successfully", "Annotation count retrieved successfully",
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}), api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
) )
@get_app_model @get_app_model
@setup_required @setup_required
@ -367,17 +214,15 @@ class MessageAnnotationCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions") @console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(Resource): class MessageSuggestedQuestionApi(Resource):
@console_ns.doc("get_message_suggested_questions") @api.doc("get_message_suggested_questions")
@console_ns.doc(description="Get suggested questions for a message") @api.doc(description="Get suggested questions for a message")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response( @api.response(
200, 200,
"Suggested questions retrieved successfully", "Suggested questions retrieved successfully",
console_ns.model( api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
) )
@console_ns.response(404, "Message or conversation not found") @api.response(404, "Message or conversation not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -411,58 +256,18 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions} return {"data": questions}
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@get_app_model
@setup_required
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
)
return export_data
except ValueError as e:
logger.exception("Parameter validation error in feedback export")
return {"error": f"Parameter validation error: {str(e)}"}, 400
except Exception as e:
logger.exception("Error exporting feedback data")
raise InternalServerError(str(e))
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>") @console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
class MessageApi(Resource): class MessageApi(Resource):
@console_ns.doc("get_message") @api.doc("get_message")
@console_ns.doc(description="Get message details by ID") @api.doc(description="Get message details by ID")
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"}) @api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
@console_ns.response(200, "Message retrieved successfully", message_detail_model) @api.response(200, "Message retrieved successfully", message_detail_fields)
@console_ns.response(404, "Message not found") @api.response(404, "Message not found")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(message_detail_model) @marshal_with(message_detail_fields)
def get(self, app_model, message_id: str): def get(self, app_model, message_id: str):
message_id = str(message_id) message_id = str(message_id)

View File

@ -3,10 +3,11 @@ from typing import cast
from flask import request from flask import request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.configuration import ToolParameterConfigurationManager
@ -20,11 +21,11 @@ from services.app_model_config_service import AppModelConfigService
@console_ns.route("/apps/<uuid:app_id>/model-config") @console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource): class ModelConfigResource(Resource):
@console_ns.doc("update_app_model_config") @api.doc("update_app_model_config")
@console_ns.doc(description="Update application model configuration") @api.doc(description="Update application model configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"ModelConfigRequest", "ModelConfigRequest",
{ {
"provider": fields.String(description="Model provider"), "provider": fields.String(description="Model provider"),
@ -42,17 +43,20 @@ class ModelConfigResource(Resource):
}, },
) )
) )
@console_ns.response(200, "Model configuration updated successfully") @api.response(200, "Model configuration updated successfully")
@console_ns.response(400, "Invalid configuration") @api.response(400, "Invalid configuration")
@console_ns.response(404, "App not found") @api.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
def post(self, app_model): def post(self, app_model):
"""Modify app model config""" """Modify app model config"""
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
@ -14,18 +14,18 @@ class TraceAppConfigApi(Resource):
Manage trace app configurations Manage trace app configurations
""" """
@console_ns.doc("get_trace_app_config") @api.doc("get_trace_app_config")
@console_ns.doc(description="Get tracing configuration for an application") @api.doc(description="Get tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.parser().add_argument( api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name" "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
) )
) )
@console_ns.response( @api.response(
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data") 200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
) )
@console_ns.response(400, "Invalid request parameters") @api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -41,11 +41,11 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@console_ns.doc("create_trace_app_config") @api.doc("create_trace_app_config")
@console_ns.doc(description="Create a new tracing configuration for an application") @api.doc(description="Create a new tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"TraceConfigCreateRequest", "TraceConfigCreateRequest",
{ {
"tracing_provider": fields.String(required=True, description="Tracing provider name"), "tracing_provider": fields.String(required=True, description="Tracing provider name"),
@ -53,10 +53,10 @@ class TraceAppConfigApi(Resource):
}, },
) )
) )
@console_ns.response( @api.response(
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data") 201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
) )
@console_ns.response(400, "Invalid request parameters or configuration already exists") @api.response(400, "Invalid request parameters or configuration already exists")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -81,11 +81,11 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@console_ns.doc("update_trace_app_config") @api.doc("update_trace_app_config")
@console_ns.doc(description="Update an existing tracing configuration for an application") @api.doc(description="Update an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"TraceConfigUpdateRequest", "TraceConfigUpdateRequest",
{ {
"tracing_provider": fields.String(required=True, description="Tracing provider name"), "tracing_provider": fields.String(required=True, description="Tracing provider name"),
@ -93,8 +93,8 @@ class TraceAppConfigApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response")) @api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
@console_ns.response(400, "Invalid request parameters or configuration not found") @api.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -117,16 +117,16 @@ class TraceAppConfigApi(Resource):
except Exception as e: except Exception as e:
raise BadRequest(str(e)) raise BadRequest(str(e))
@console_ns.doc("delete_trace_app_config") @api.doc("delete_trace_app_config")
@console_ns.doc(description="Delete an existing tracing configuration for an application") @api.doc(description="Delete an existing tracing configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.parser().add_argument( api.parser().add_argument(
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name" "tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
) )
) )
@console_ns.response(204, "Tracing configuration deleted successfully") @api.response(204, "Tracing configuration deleted successfully")
@console_ns.response(400, "Invalid request parameters or configuration not found") @api.response(400, "Invalid request parameters or configuration not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,24 +1,16 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, setup_required
account_initialization_required,
edit_permission_required,
is_admin_or_owner_required,
setup_required,
)
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import Site from models import Site
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
def parse_app_site_args(): def parse_app_site_args():
parser = ( parser = (
@ -51,11 +43,11 @@ def parse_app_site_args():
@console_ns.route("/apps/<uuid:app_id>/site") @console_ns.route("/apps/<uuid:app_id>/site")
class AppSite(Resource): class AppSite(Resource):
@console_ns.doc("update_app_site") @api.doc("update_app_site")
@console_ns.doc(description="Update application site configuration") @api.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"AppSiteRequest", "AppSiteRequest",
{ {
"title": fields.String(description="Site title"), "title": fields.String(description="Site title"),
@ -79,18 +71,22 @@ class AppSite(Resource):
}, },
) )
) )
@console_ns.response(200, "Site configuration updated successfully", app_site_model) @api.response(200, "Site configuration updated successfully", app_site_fields)
@console_ns.response(403, "Insufficient permissions") @api.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found") @api.response(404, "App not found")
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_site_model) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
args = parse_app_site_args() args = parse_app_site_args()
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.has_edit_permission:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound
@ -126,20 +122,24 @@ class AppSite(Resource):
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset") @console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
class AppSiteAccessTokenReset(Resource): class AppSiteAccessTokenReset(Resource):
@console_ns.doc("reset_app_site_access_token") @api.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site") @api.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", app_site_model) @api.response(200, "Access token reset successfully", app_site_fields)
@console_ns.response(403, "Insufficient permissions (admin/owner required)") @api.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found") @api.response(404, "App or site not found")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
@marshal_with(app_site_model) @marshal_with(app_site_fields)
def post(self, app_model): def post(self, app_model):
# The role of the current user in the ta table must be admin or owner
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:

View File

@ -1,48 +1,31 @@
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa import sqlalchemy as sa
from flask import abort, jsonify, request from flask import abort, jsonify
from flask_restx import Resource, fields from flask_restx import Resource, fields, reqparse
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import convert_datetime_to_date from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models import AppMode from models import AppMode, Message
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def empty_string_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
StatisticTimeRangeQuery.__name__,
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
class DailyMessageStatistic(Resource): class DailyMessageStatistic(Resource):
@console_ns.doc("get_daily_message_statistics") @api.doc("get_daily_message_statistics")
@console_ns.doc(description="Get daily message statistics for an application") @api.doc(description="Get daily message statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(
@console_ns.response( api.parser()
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@api.response(
200, 200,
"Daily message statistics retrieved successfully", "Daily message statistics retrieved successfully",
fields.List(fields.Raw(description="Daily message count data")), fields.List(fields.Raw(description="Daily message count data")),
@ -54,11 +37,15 @@ class DailyMessageStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
COUNT(*) AS message_count COUNT(*) AS message_count
FROM FROM
messages messages
@ -69,7 +56,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -93,13 +80,20 @@ WHERE
return jsonify({"data": response_data}) return jsonify({"data": response_data})
parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="Start date (YYYY-MM-DD HH:MM)")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args", help="End date (YYYY-MM-DD HH:MM)")
)
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
class DailyConversationStatistic(Resource): class DailyConversationStatistic(Resource):
@console_ns.doc("get_daily_conversation_statistics") @api.doc("get_daily_conversation_statistics")
@console_ns.doc(description="Get daily conversation statistics for an application") @api.doc(description="Get daily conversation statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Daily conversation statistics retrieved successfully", "Daily conversation statistics retrieved successfully",
fields.List(fields.Raw(description="Daily conversation count data")), fields.List(fields.Raw(description="Daily conversation count data")),
@ -111,51 +105,49 @@ class DailyConversationStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
stmt = (
sa.select(
sa.func.date(
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
).label("date"),
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
)
.select_from(Message)
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER)
)
if start_datetime_utc: if start_datetime_utc:
sql_query += " AND created_at >= :start" stmt = stmt.where(Message.created_at >= start_datetime_utc)
arg_dict["start"] = start_datetime_utc
if end_datetime_utc: if end_datetime_utc:
sql_query += " AND created_at < :end" stmt = stmt.where(Message.created_at < end_datetime_utc)
arg_dict["end"] = end_datetime_utc
sql_query += " GROUP BY date ORDER BY date" stmt = stmt.group_by("date").order_by("date")
response_data = [] response_data = []
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(sa.text(sql_query), arg_dict) rs = conn.execute(stmt, {"tz": account.timezone})
for i in rs: for row in rs:
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
return jsonify({"data": response_data}) return jsonify({"data": response_data})
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users") @console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
class DailyTerminalsStatistic(Resource): class DailyTerminalsStatistic(Resource):
@console_ns.doc("get_daily_terminals_statistics") @api.doc("get_daily_terminals_statistics")
@console_ns.doc(description="Get daily terminal/end-user statistics for an application") @api.doc(description="Get daily terminal/end-user statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Daily terminal statistics retrieved successfully", "Daily terminal statistics retrieved successfully",
fields.List(fields.Raw(description="Daily terminal count data")), fields.List(fields.Raw(description="Daily terminal count data")),
@ -167,11 +159,10 @@ class DailyTerminalsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
FROM FROM
messages messages
@ -182,7 +173,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -208,11 +199,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs") @console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
class DailyTokenCostStatistic(Resource): class DailyTokenCostStatistic(Resource):
@console_ns.doc("get_daily_token_cost_statistics") @api.doc("get_daily_token_cost_statistics")
@console_ns.doc(description="Get daily token cost statistics for an application") @api.doc(description="Get daily token cost statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Daily token cost statistics retrieved successfully", "Daily token cost statistics retrieved successfully",
fields.List(fields.Raw(description="Daily token cost data")), fields.List(fields.Raw(description="Daily token cost data")),
@ -224,11 +215,10 @@ class DailyTokenCostStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count, (SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
SUM(total_price) AS total_price SUM(total_price) AS total_price
FROM FROM
@ -240,7 +230,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -268,11 +258,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions") @console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
class AverageSessionInteractionStatistic(Resource): class AverageSessionInteractionStatistic(Resource):
@console_ns.doc("get_average_session_interaction_statistics") @api.doc("get_average_session_interaction_statistics")
@console_ns.doc(description="Get average session interaction statistics for an application") @api.doc(description="Get average session interaction statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Average session interaction statistics retrieved successfully", "Average session interaction statistics retrieved successfully",
fields.List(fields.Raw(description="Average session interaction data")), fields.List(fields.Raw(description="Average session interaction data")),
@ -284,11 +274,10 @@ class AverageSessionInteractionStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("c.created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
AVG(subquery.message_count) AS interactions AVG(subquery.message_count) AS interactions
FROM FROM
( (
@ -307,7 +296,7 @@ FROM
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -344,11 +333,11 @@ ORDER BY
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate") @console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
class UserSatisfactionRateStatistic(Resource): class UserSatisfactionRateStatistic(Resource):
@console_ns.doc("get_user_satisfaction_rate_statistics") @api.doc("get_user_satisfaction_rate_statistics")
@console_ns.doc(description="Get user satisfaction rate statistics for an application") @api.doc(description="Get user satisfaction rate statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"User satisfaction rate statistics retrieved successfully", "User satisfaction rate statistics retrieved successfully",
fields.List(fields.Raw(description="User satisfaction rate data")), fields.List(fields.Raw(description="User satisfaction rate data")),
@ -360,11 +349,10 @@ class UserSatisfactionRateStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("m.created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
COUNT(m.id) AS message_count, COUNT(m.id) AS message_count,
COUNT(mf.id) AS feedback_count COUNT(mf.id) AS feedback_count
FROM FROM
@ -379,7 +367,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -410,11 +398,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time") @console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
class AverageResponseTimeStatistic(Resource): class AverageResponseTimeStatistic(Resource):
@console_ns.doc("get_average_response_time_statistics") @api.doc("get_average_response_time_statistics")
@console_ns.doc(description="Get average response time statistics for an application") @api.doc(description="Get average response time statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Average response time statistics retrieved successfully", "Average response time statistics retrieved successfully",
fields.List(fields.Raw(description="Average response time data")), fields.List(fields.Raw(description="Average response time data")),
@ -426,11 +414,10 @@ class AverageResponseTimeStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
AVG(provider_response_latency) AS latency AVG(provider_response_latency) AS latency
FROM FROM
messages messages
@ -441,7 +428,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -467,11 +454,11 @@ WHERE
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second") @console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
class TokensPerSecondStatistic(Resource): class TokensPerSecondStatistic(Resource):
@console_ns.doc("get_tokens_per_second_statistics") @api.doc("get_tokens_per_second_statistics")
@console_ns.doc(description="Get tokens per second statistics for an application") @api.doc(description="Get tokens per second statistics for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__]) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Tokens per second statistics retrieved successfully", "Tokens per second statistics retrieved successfully",
fields.List(fields.Raw(description="Tokens per second data")), fields.List(fields.Raw(description="Tokens per second data")),
@ -482,11 +469,10 @@ class TokensPerSecondStatistic(Resource):
@account_initialization_required @account_initialization_required
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser.parse_args()
converted_created_at = convert_datetime_to_date("created_at") sql_query = """SELECT
sql_query = f"""SELECT DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
{converted_created_at} AS date,
CASE CASE
WHEN SUM(provider_response_latency) = 0 THEN 0 WHEN SUM(provider_response_latency) = 0 THEN 0
ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
@ -500,7 +486,7 @@ WHERE
assert account.timezone is not None assert account.timezone is not None
try: try:
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone) start_datetime_utc, end_datetime_utc = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))

View File

@ -1,16 +1,15 @@
import json import json
import logging import logging
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import cast
from flask import abort, request from flask import abort, request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, inputs, marshal_with, reqparse
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@ -33,7 +32,6 @@ from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory, variable_factory from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
@ -50,161 +48,6 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LISTENING_RETRY_IN = 2000 LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
from fields.workflow_fields import pipeline_variable_fields, serialize_value_type
conversation_variable_model = console_ns.model(
"ConversationVariable",
{
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw,
"description": fields.String,
},
)
pipeline_variable_model = console_ns.model("PipelineVariable", pipeline_variable_fields)
# Workflow model with nested dependencies
workflow_fields_copy = workflow_fields.copy()
workflow_fields_copy["created_by"] = fields.Nested(simple_account_model, attribute="created_by_account")
workflow_fields_copy["updated_by"] = fields.Nested(
simple_account_model, attribute="updated_by_account", allow_null=True
)
workflow_fields_copy["conversation_variables"] = fields.List(fields.Nested(conversation_variable_model))
workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipeline_variable_model))
workflow_model = console_ns.model("Workflow", workflow_fields_copy)
# Workflow pagination model
workflow_pagination_fields_copy = workflow_pagination_fields.copy()
workflow_pagination_fields_copy["items"] = fields.List(fields.Nested(workflow_model), attribute="items")
workflow_pagination_model = console_ns.model("WorkflowPagination", workflow_pagination_fields_copy)
# Reuse workflow_run_node_execution_model from workflow_run.py if already registered
# Otherwise register it here
from fields.end_user_fields import simple_end_user_fields
simple_end_user_model = None
try:
simple_end_user_model = console_ns.models.get("SimpleEndUser")
except AttributeError:
pass
if simple_end_user_model is None:
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
workflow_run_node_execution_model = None
try:
workflow_run_node_execution_model = console_ns.models.get("WorkflowRunNodeExecution")
except AttributeError:
pass
if workflow_run_node_execution_model is None:
workflow_run_node_execution_model = console_ns.model("WorkflowRunNodeExecution", workflow_run_node_execution_fields)
class SyncDraftWorkflowPayload(BaseModel):
graph: dict[str, Any]
features: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] = Field(default_factory=list)
conversation_variables: list[dict[str, Any]] = Field(default_factory=list)
class BaseWorkflowRunPayload(BaseModel):
files: list[dict[str, Any]] | None = None
class AdvancedChatWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any] | None = None
query: str = ""
conversation_id: str | None = None
parent_message_id: str | None = None
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class IterationNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class LoopNodeRunPayload(BaseModel):
inputs: dict[str, Any] | None = None
class DraftWorkflowRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
class DraftWorkflowNodeRunPayload(BaseWorkflowRunPayload):
inputs: dict[str, Any]
query: str = ""
class PublishWorkflowPayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(SyncDraftWorkflowPayload)
reg(AdvancedChatWorkflowRunPayload)
reg(IterationNodeRunPayload)
reg(LoopNodeRunPayload)
reg(DraftWorkflowRunPayload)
reg(DraftWorkflowNodeRunPayload)
reg(PublishWorkflowPayload)
reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing # TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@ -227,16 +70,16 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence
@console_ns.route("/apps/<uuid:app_id>/workflows/draft") @console_ns.route("/apps/<uuid:app_id>/workflows/draft")
class DraftWorkflowApi(Resource): class DraftWorkflowApi(Resource):
@console_ns.doc("get_draft_workflow") @api.doc("get_draft_workflow")
@console_ns.doc(description="Get draft workflow for an application") @api.doc(description="Get draft workflow for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model) @api.response(200, "Draft workflow retrieved successfully", workflow_fields)
@console_ns.response(404, "Draft workflow not found") @api.response(404, "Draft workflow not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model) @marshal_with(workflow_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
@ -256,13 +99,24 @@ class DraftWorkflowApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@console_ns.doc("sync_draft_workflow") @api.doc("sync_draft_workflow")
@console_ns.doc(description="Sync draft workflow configuration") @api.doc(description="Sync draft workflow configuration")
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__]) @api.expect(
@console_ns.response( api.model(
"SyncDraftWorkflowRequest",
{
"graph": fields.Raw(required=True, description="Workflow graph configuration"),
"features": fields.Raw(required=True, description="Workflow features configuration"),
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@api.response(
200, 200,
"Draft workflow synced successfully", "Draft workflow synced successfully",
console_ns.model( api.model(
"SyncDraftWorkflowResponse", "SyncDraftWorkflowResponse",
{ {
"result": fields.String, "result": fields.String,
@ -271,8 +125,8 @@ class DraftWorkflowApi(Resource):
}, },
), ),
) )
@console_ns.response(400, "Invalid workflow configuration") @api.response(400, "Invalid workflow configuration")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@edit_permission_required @edit_permission_required
def post(self, app_model: App): def post(self, app_model: App):
""" """
@ -282,23 +136,36 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type: if "application/json" in content_type:
payload_data = request.get_json(silent=True) parser = (
if not isinstance(payload_data, dict): reqparse.RequestParser()
return {"message": "Invalid JSON data"}, 400 .add_argument("graph", type=dict, required=True, nullable=False, location="json")
.add_argument("features", type=dict, required=True, nullable=False, location="json")
.add_argument("hash", type=str, required=False, location="json")
.add_argument("environment_variables", type=list, required=True, location="json")
.add_argument("conversation_variables", type=list, required=False, location="json")
)
args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
payload_data = json.loads(request.data.decode("utf-8")) data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict):
raise ValueError("graph or features is not a dict")
args = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
}
except json.JSONDecodeError: except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else: else:
abort(415) abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService() workflow_service = WorkflowService()
try: try:
@ -331,13 +198,23 @@ class DraftWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
class AdvancedChatDraftWorkflowRunApi(Resource): class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.doc("run_advanced_chat_draft_workflow") @api.doc("run_advanced_chat_draft_workflow")
@console_ns.doc(description="Run draft workflow for advanced chat application") @api.doc(description="Run draft workflow for advanced chat application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AdvancedChatWorkflowRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Workflow run started successfully") api.model(
@console_ns.response(400, "Invalid request parameters") "AdvancedChatWorkflowRunRequest",
@console_ns.response(403, "Permission denied") {
"query": fields.String(required=True, description="User query"),
"inputs": fields.Raw(description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
"conversation_id": fields.String(description="Conversation ID"),
},
)
)
@api.response(200, "Workflow run started successfully")
@api.response(400, "Invalid request parameters")
@api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -349,8 +226,16 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {}) parser = (
args = args_model.model_dump(exclude_none=True) reqparse.RequestParser()
.add_argument("inputs", type=dict, location="json")
.add_argument("query", type=str, required=True, location="json", default="")
.add_argument("files", type=list, location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
@ -377,13 +262,21 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run")
class AdvancedChatDraftRunIterationNodeApi(Resource): class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_iteration_node") @api.doc("run_advanced_chat_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for advanced chat") @api.doc(description="Run draft workflow iteration node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Iteration node run started successfully") api.model(
@console_ns.response(403, "Permission denied") "IterationNodeRunRequest",
@console_ns.response(404, "Node not found") {
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@api.response(200, "Iteration node run started successfully")
@api.response(403, "Permission denied")
@api.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -394,7 +287,8 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node Run draft workflow iteration node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
@ -415,13 +309,21 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class WorkflowDraftRunIterationNodeApi(Resource): class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_workflow_draft_iteration_node") @api.doc("run_workflow_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node") @api.doc(description="Run draft workflow iteration node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[IterationNodeRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Workflow iteration node run started successfully") api.model(
@console_ns.response(403, "Permission denied") "WorkflowIterationNodeRunRequest",
@console_ns.response(404, "Node not found") {
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@api.response(200, "Workflow iteration node run started successfully")
@api.response(403, "Permission denied")
@api.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -432,7 +334,8 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node Run draft workflow iteration node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_iteration( response = AppGenerateService.generate_single_iteration(
@ -453,13 +356,21 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run")
class AdvancedChatDraftRunLoopNodeApi(Resource): class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_advanced_chat_draft_loop_node") @api.doc("run_advanced_chat_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for advanced chat") @api.doc(description="Run draft workflow loop node for advanced chat")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Loop node run started successfully") api.model(
@console_ns.response(403, "Permission denied") "LoopNodeRunRequest",
@console_ns.response(404, "Node not found") {
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@api.response(200, "Loop node run started successfully")
@api.response(403, "Permission denied")
@api.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -470,7 +381,8 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_loop( response = AppGenerateService.generate_single_loop(
@ -491,13 +403,21 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class WorkflowDraftRunLoopNodeApi(Resource): class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_workflow_draft_loop_node") @api.doc("run_workflow_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node") @api.doc(description="Run draft workflow loop node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[LoopNodeRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Workflow loop node run started successfully") api.model(
@console_ns.response(403, "Permission denied") "WorkflowLoopNodeRunRequest",
@console_ns.response(404, "Node not found") {
"task_id": fields.String(required=True, description="Task ID"),
"inputs": fields.Raw(description="Input variables"),
},
)
)
@api.response(200, "Workflow loop node run started successfully")
@api.response(403, "Permission denied")
@api.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -508,7 +428,8 @@ class WorkflowDraftRunLoopNodeApi(Resource):
Run draft workflow loop node Run draft workflow loop node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) parser = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
try: try:
response = AppGenerateService.generate_single_loop( response = AppGenerateService.generate_single_loop(
@ -529,12 +450,20 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource): class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow") @api.doc("run_draft_workflow")
@console_ns.doc(description="Run draft workflow") @api.doc(description="Run draft workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Draft workflow run started successfully") api.model(
@console_ns.response(403, "Permission denied") "DraftWorkflowRunRequest",
{
"inputs": fields.Raw(required=True, description="Input variables"),
"files": fields.List(fields.Raw, description="File uploads"),
},
)
)
@api.response(200, "Draft workflow run started successfully")
@api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -545,7 +474,12 @@ class DraftWorkflowRunApi(Resource):
Run draft workflow Run draft workflow
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("files", type=list, required=False, location="json")
)
args = parser.parse_args()
external_trace_id = get_external_trace_id(request) external_trace_id = get_external_trace_id(request)
if external_trace_id: if external_trace_id:
@ -567,12 +501,12 @@ class DraftWorkflowRunApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
class WorkflowTaskStopApi(Resource): class WorkflowTaskStopApi(Resource):
@console_ns.doc("stop_workflow_task") @api.doc("stop_workflow_task")
@console_ns.doc(description="Stop running workflow task") @api.doc(description="Stop running workflow task")
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID"}) @api.doc(params={"app_id": "Application ID", "task_id": "Task ID"})
@console_ns.response(200, "Task stopped successfully") @api.response(200, "Task stopped successfully")
@console_ns.response(404, "Task not found") @api.response(404, "Task not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -594,28 +528,40 @@ class WorkflowTaskStopApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
class DraftWorkflowNodeRunApi(Resource): class DraftWorkflowNodeRunApi(Resource):
@console_ns.doc("run_draft_workflow_node") @api.doc("run_draft_workflow_node")
@console_ns.doc(description="Run draft workflow node") @api.doc(description="Run draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[DraftWorkflowNodeRunPayload.__name__]) @api.expect(
@console_ns.response(200, "Node run started successfully", workflow_run_node_execution_model) api.model(
@console_ns.response(403, "Permission denied") "DraftWorkflowNodeRunRequest",
@console_ns.response(404, "Node not found") {
"inputs": fields.Raw(description="Input variables"),
},
)
)
@api.response(200, "Node run started successfully", workflow_run_node_execution_fields)
@api.response(403, "Permission denied")
@api.response(404, "Node not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_model) @marshal_with(workflow_run_node_execution_fields)
@edit_permission_required @edit_permission_required
def post(self, app_model: App, node_id: str): def post(self, app_model: App, node_id: str):
""" """
Run draft workflow node Run draft workflow node
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {}) parser = (
args = args_model.model_dump(exclude_none=True) reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
.add_argument("query", type=str, required=False, location="json", default="")
.add_argument("files", type=list, location="json", default=[])
)
args = parser.parse_args()
user_inputs = args_model.inputs user_inputs = args.get("inputs")
if user_inputs is None: if user_inputs is None:
raise ValueError("missing inputs") raise ValueError("missing inputs")
@ -640,18 +586,25 @@ class DraftWorkflowNodeRunApi(Resource):
return workflow_node_execution return workflow_node_execution
parser_publish = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, default="", location="json")
.add_argument("marked_comment", type=str, required=False, default="", location="json")
)
@console_ns.route("/apps/<uuid:app_id>/workflows/publish") @console_ns.route("/apps/<uuid:app_id>/workflows/publish")
class PublishedWorkflowApi(Resource): class PublishedWorkflowApi(Resource):
@console_ns.doc("get_published_workflow") @api.doc("get_published_workflow")
@console_ns.doc(description="Get published workflow for an application") @api.doc(description="Get published workflow for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model) @api.response(200, "Published workflow retrieved successfully", workflow_fields)
@console_ns.response(404, "Published workflow not found") @api.response(404, "Published workflow not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model) @marshal_with(workflow_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
@ -664,7 +617,7 @@ class PublishedWorkflowApi(Resource):
# return workflow, if not found, return None # return workflow, if not found, return None
return workflow return workflow
@console_ns.expect(console_ns.models[PublishWorkflowPayload.__name__]) @api.expect(parser_publish)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -676,7 +629,13 @@ class PublishedWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) args = parser_publish.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -707,10 +666,10 @@ class PublishedWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
class DefaultBlockConfigsApi(Resource): class DefaultBlockConfigsApi(Resource):
@console_ns.doc("get_default_block_configs") @api.doc("get_default_block_configs")
@console_ns.doc(description="Get default block configurations for workflow") @api.doc(description="Get default block configurations for workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Default block configurations retrieved successfully") @api.response(200, "Default block configurations retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -725,14 +684,17 @@ class DefaultBlockConfigsApi(Resource):
return workflow_service.get_default_block_configs() return workflow_service.get_default_block_configs()
parser_block = reqparse.RequestParser().add_argument("q", type=str, location="args")
@console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultBlockConfigApi(Resource): class DefaultBlockConfigApi(Resource):
@console_ns.doc("get_default_block_config") @api.doc("get_default_block_config")
@console_ns.doc(description="Get default block configuration by type") @api.doc(description="Get default block configuration by type")
@console_ns.doc(params={"app_id": "Application ID", "block_type": "Block type"}) @api.doc(params={"app_id": "Application ID", "block_type": "Block type"})
@console_ns.response(200, "Default block configuration retrieved successfully") @api.response(200, "Default block configuration retrieved successfully")
@console_ns.response(404, "Block type not found") @api.response(404, "Block type not found")
@console_ns.expect(console_ns.models[DefaultBlockConfigQuery.__name__]) @api.expect(parser_block)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -742,12 +704,14 @@ class DefaultBlockConfigApi(Resource):
""" """
Get default block config Get default block config
""" """
args = DefaultBlockConfigQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_block.parse_args()
q = args.get("q")
filters = None filters = None
if args.q: if q:
try: try:
filters = json.loads(args.q) filters = json.loads(args.get("q", ""))
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError("Invalid filters") raise ValueError("Invalid filters")
@ -756,15 +720,24 @@ class DefaultBlockConfigApi(Resource):
return workflow_service.get_default_block_config(node_type=block_type, filters=filters) return workflow_service.get_default_block_config(node_type=block_type, filters=filters)
parser_convert = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
.add_argument("icon_type", type=str, required=False, nullable=True, location="json")
.add_argument("icon", type=str, required=False, nullable=True, location="json")
.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
)
@console_ns.route("/apps/<uuid:app_id>/convert-to-workflow") @console_ns.route("/apps/<uuid:app_id>/convert-to-workflow")
class ConvertToWorkflowApi(Resource): class ConvertToWorkflowApi(Resource):
@console_ns.expect(console_ns.models[ConvertToWorkflowPayload.__name__]) @api.expect(parser_convert)
@console_ns.doc("convert_to_workflow") @api.doc("convert_to_workflow")
@console_ns.doc(description="Convert application to workflow mode") @api.doc(description="Convert application to workflow mode")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Application converted to workflow successfully") @api.response(200, "Application converted to workflow successfully")
@console_ns.response(400, "Application cannot be converted") @api.response(400, "Application cannot be converted")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -778,8 +751,10 @@ class ConvertToWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} if request.data:
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True) args = parser_convert.parse_args()
else:
args = {}
# convert to workflow mode # convert to workflow mode
workflow_service = WorkflowService() workflow_service = WorkflowService()
@ -791,18 +766,27 @@ class ConvertToWorkflowApi(Resource):
} }
parser_workflows = (
reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
.add_argument("user_id", type=str, required=False, location="args")
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
)
@console_ns.route("/apps/<uuid:app_id>/workflows") @console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource): class PublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__]) @api.expect(parser_workflows)
@console_ns.doc("get_all_published_workflows") @api.doc("get_all_published_workflows")
@console_ns.doc(description="Get all published workflows for an application") @api.doc(description="Get all published workflows for an application")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Published workflows retrieved successfully", workflow_pagination_model) @api.response(200, "Published workflows retrieved successfully", workflow_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_pagination_model) @marshal_with(workflow_pagination_fields)
@edit_permission_required @edit_permission_required
def get(self, app_model: App): def get(self, app_model: App):
""" """
@ -810,15 +794,16 @@ class PublishedAllWorkflowApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_workflows.parse_args()
page = args.page page = args["page"]
limit = args.limit limit = args["limit"]
user_id = args.user_id user_id = args.get("user_id")
named_only = args.named_only named_only = args.get("named_only", False)
if user_id: if user_id:
if user_id != current_user.id: if user_id != current_user.id:
raise Forbidden() raise Forbidden()
user_id = cast(str, user_id)
workflow_service = WorkflowService() workflow_service = WorkflowService()
with Session(db.engine) as session: with Session(db.engine) as session:
@ -841,32 +826,51 @@ class PublishedAllWorkflowApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>") @console_ns.route("/apps/<uuid:app_id>/workflows/<string:workflow_id>")
class WorkflowByIdApi(Resource): class WorkflowByIdApi(Resource):
@console_ns.doc("update_workflow_by_id") @api.doc("update_workflow_by_id")
@console_ns.doc(description="Update workflow by ID") @api.doc(description="Update workflow by ID")
@console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"}) @api.doc(params={"app_id": "Application ID", "workflow_id": "Workflow ID"})
@console_ns.expect(console_ns.models[WorkflowUpdatePayload.__name__]) @api.expect(
@console_ns.response(200, "Workflow updated successfully", workflow_model) api.model(
@console_ns.response(404, "Workflow not found") "UpdateWorkflowRequest",
@console_ns.response(403, "Permission denied") {
"environment_variables": fields.List(fields.Raw, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
},
)
)
@api.response(200, "Workflow updated successfully", workflow_fields)
@api.response(404, "Workflow not found")
@api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_model) @marshal_with(workflow_fields)
@edit_permission_required @edit_permission_required
def patch(self, app_model: App, workflow_id: str): def patch(self, app_model: App, workflow_id: str):
""" """
Update workflow attributes Update workflow attributes
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) parser = (
reqparse.RequestParser()
.add_argument("marked_name", type=str, required=False, location="json")
.add_argument("marked_comment", type=str, required=False, location="json")
)
args = parser.parse_args()
# Validate name and comment length
if args.marked_name and len(args.marked_name) > 20:
raise ValueError("Marked name cannot exceed 20 characters")
if args.marked_comment and len(args.marked_comment) > 100:
raise ValueError("Marked comment cannot exceed 100 characters")
# Prepare update data # Prepare update data
update_data = {} update_data = {}
if args.marked_name is not None: if args.get("marked_name") is not None:
update_data["marked_name"] = args.marked_name update_data["marked_name"] = args["marked_name"]
if args.marked_comment is not None: if args.get("marked_comment") is not None:
update_data["marked_comment"] = args.marked_comment update_data["marked_comment"] = args["marked_comment"]
if not update_data: if not update_data:
return {"message": "No valid fields to update"}, 400 return {"message": "No valid fields to update"}, 400
@ -922,17 +926,17 @@ class WorkflowByIdApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run")
class DraftWorkflowNodeLastRunApi(Resource): class DraftWorkflowNodeLastRunApi(Resource):
@console_ns.doc("get_draft_workflow_node_last_run") @api.doc("get_draft_workflow_node_last_run")
@console_ns.doc(description="Get last run result for draft workflow node") @api.doc(description="Get last run result for draft workflow node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model) @api.response(200, "Node last run retrieved successfully", workflow_run_node_execution_fields)
@console_ns.response(404, "Node last run not found") @api.response(404, "Node last run not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_model) @marshal_with(workflow_run_node_execution_fields)
def get(self, app_model: App, node_id: str): def get(self, app_model: App, node_id: str):
srv = WorkflowService() srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model) workflow = srv.get_draft_workflow(app_model)
@ -955,20 +959,20 @@ class DraftWorkflowTriggerRunApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run Path: /apps/<uuid:app_id>/workflows/draft/trigger/run
""" """
@console_ns.doc("poll_draft_workflow_trigger_run") @api.doc("poll_draft_workflow_trigger_run")
@console_ns.doc(description="Poll for trigger events and execute full workflow when event arrives") @api.doc(description="Poll for trigger events and execute full workflow when event arrives")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"DraftWorkflowTriggerRunRequest", "DraftWorkflowTriggerRunRequest",
{ {
"node_id": fields.String(required=True, description="Node ID"), "node_id": fields.String(required=True, description="Node ID"),
}, },
) )
) )
@console_ns.response(200, "Trigger event received and workflow executed successfully") @api.response(200, "Trigger event received and workflow executed successfully")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@console_ns.response(500, "Internal server error") @api.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -979,8 +983,10 @@ class DraftWorkflowTriggerRunApi(Resource):
Poll for trigger events and execute full workflow when event arrives Poll for trigger events and execute full workflow when event arrives
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {}) parser = reqparse.RequestParser()
node_id = args.node_id parser.add_argument("node_id", type=str, required=True, location="json", nullable=False)
args = parser.parse_args()
node_id = args["node_id"]
workflow_service = WorkflowService() workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model) draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow: if not draft_workflow:
@ -1026,12 +1032,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run Path: /apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/trigger/run
""" """
@console_ns.doc("poll_draft_workflow_trigger_node") @api.doc("poll_draft_workflow_trigger_node")
@console_ns.doc(description="Poll for trigger events and execute single node when event arrives") @api.doc(description="Poll for trigger events and execute single node when event arrives")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.response(200, "Trigger event received and node executed successfully") @api.response(200, "Trigger event received and node executed successfully")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@console_ns.response(500, "Internal server error") @api.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1105,13 +1111,20 @@ class DraftWorkflowTriggerRunAllApi(Resource):
Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all Path: /apps/<uuid:app_id>/workflows/draft/trigger/run-all
""" """
@console_ns.doc("draft_workflow_trigger_run_all") @api.doc("draft_workflow_trigger_run_all")
@console_ns.doc(description="Full workflow debug when the start node is a trigger") @api.doc(description="Full workflow debug when the start node is a trigger")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[DraftWorkflowTriggerRunAllPayload.__name__]) @api.expect(
@console_ns.response(200, "Workflow executed successfully") api.model(
@console_ns.response(403, "Permission denied") "DraftWorkflowTriggerRunAllRequest",
@console_ns.response(500, "Internal server error") {
"node_ids": fields.List(fields.String, required=True, description="Node IDs"),
},
)
)
@api.response(200, "Workflow executed successfully")
@api.response(403, "Permission denied")
@api.response(500, "Internal server error")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1123,8 +1136,10 @@ class DraftWorkflowTriggerRunAllApi(Resource):
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {}) parser = reqparse.RequestParser()
node_ids = args.node_ids parser.add_argument("node_ids", type=list, required=True, location="json", nullable=False)
args = parser.parse_args()
node_ids = args["node_ids"]
workflow_service = WorkflowService() workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model) draft_workflow = workflow_service.get_draft_workflow(app_model)
if not draft_workflow: if not draft_workflow:

View File

@ -1,85 +1,86 @@
from datetime import datetime
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with from flask_restx.inputs import int_range
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required from libs.login import login_required
from models import App from models import App
from models.model import AppMode from models.model import AppMode
from services.workflow_app_service import WorkflowAppService from services.workflow_app_service import WorkflowAppService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowAppLogQuery(BaseModel):
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
status: WorkflowExecutionStatus | None = Field(
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
)
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
created_by_account: str | None = Field(default=None, description="Filter by account")
detail: bool = Field(default=False, description="Whether to return detailed logs")
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
@field_validator("created_at__before", "created_at__after", mode="before")
@classmethod
def parse_datetime(cls, value: str | None) -> datetime | None:
if value in (None, ""):
return None
return isoparse(value) # type: ignore
@field_validator("detail", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
lowered = value.lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
raise ValueError("Invalid boolean value for detail")
console_ns.schema_model(
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
# Register model for flask_restx to avoid dict type issues in Swagger
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs") @console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
class WorkflowAppLogApi(Resource): class WorkflowAppLogApi(Resource):
@console_ns.doc("get_workflow_app_logs") @api.doc("get_workflow_app_logs")
@console_ns.doc(description="Get workflow application execution logs") @api.doc(description="Get workflow application execution logs")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__]) @api.doc(
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model) params={
"keyword": "Search keyword for filtering logs",
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
"created_at__before": "Filter logs created before this timestamp",
"created_at__after": "Filter logs created after this timestamp",
"created_by_end_user_session_id": "Filter by end user session ID",
"created_by_account": "Filter by account",
"detail": "Whether to return detailed logs",
"page": "Page number (1-99999)",
"limit": "Number of items per page (1-100)",
}
)
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.WORKFLOW])
@marshal_with(workflow_app_log_pagination_model) @marshal_with(workflow_app_log_pagination_fields)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get workflow app logs Get workflow app logs
""" """
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("keyword", type=str, location="args")
.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)
.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
.add_argument("detail", type=bool, location="args", required=False, default=False)
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
)
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

View File

@ -1,18 +1,17 @@
import logging import logging
from collections.abc import Callable from typing import NoReturn
from functools import wraps
from typing import NoReturn, ParamSpec, TypeVar
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
DraftWorkflowNotExist, DraftWorkflowNotExist,
) )
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup from core.variables.segment_group import SegmentGroup
@ -22,8 +21,8 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
from extensions.ext_database import db from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type from factories.variable_factory import build_segment_with_type
from libs.login import login_required from libs.login import current_user, login_required
from models import App, AppMode from models import Account, App, AppMode
from models.workflow import WorkflowDraftVariable from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -141,42 +140,8 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items), "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
} }
# Register models for flask_restx to avoid dict type issues in Swagger
workflow_draft_variable_without_value_model = console_ns.model(
"WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
)
workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS) def _api_prerequisite(f):
workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
workflow_draft_env_variable_list_model = console_ns.model(
"WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
)
workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
)
workflow_draft_variable_list_without_value_model = console_ns.model(
"WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
)
workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
workflow_draft_variable_list_fields_copy["items"] = fields.List(
fields.Nested(workflow_draft_variable_model), attribute=_get_items
)
workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs. """Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied: It ensures the following conditions are satisfied:
@ -190,10 +155,11 @@ def _api_prerequisite(f: Callable[P, R]):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f) def wrapper(*args, **kwargs):
def wrapper(*args: P.args, **kwargs: P.kwargs): assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs) return f(*args, **kwargs)
return wrapper return wrapper
@ -201,16 +167,13 @@ def _api_prerequisite(f: Callable[P, R]):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
class WorkflowVariableCollectionApi(Resource): class WorkflowVariableCollectionApi(Resource):
@console_ns.expect(_create_pagination_parser()) @api.doc("get_workflow_variables")
@console_ns.doc("get_workflow_variables") @api.doc(description="Get draft workflow variables")
@console_ns.doc(description="Get draft workflow variables") @api.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
@console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"}) @api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
@console_ns.response(
200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
)
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_list_without_value_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get draft workflow Get draft workflow
@ -237,9 +200,9 @@ class WorkflowVariableCollectionApi(Resource):
return workflow_vars return workflow_vars
@console_ns.doc("delete_workflow_variables") @api.doc("delete_workflow_variables")
@console_ns.doc(description="Delete all draft workflow variables") @api.doc(description="Delete all draft workflow variables")
@console_ns.response(204, "Workflow variables deleted successfully") @api.response(204, "Workflow variables deleted successfully")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App): def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -270,12 +233,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
class NodeVariableCollectionApi(Resource): class NodeVariableCollectionApi(Resource):
@console_ns.doc("get_node_variables") @api.doc("get_node_variables")
@console_ns.doc(description="Get variables for a specific node") @api.doc(description="Get variables for a specific node")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"}) @api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model) @api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_list_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str): def get(self, app_model: App, node_id: str):
validate_node_id(node_id) validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session: with Session(bind=db.engine, expire_on_commit=False) as session:
@ -286,9 +249,9 @@ class NodeVariableCollectionApi(Resource):
return node_vars return node_vars
@console_ns.doc("delete_node_variables") @api.doc("delete_node_variables")
@console_ns.doc(description="Delete all variables for a specific node") @api.doc(description="Delete all variables for a specific node")
@console_ns.response(204, "Node variables deleted successfully") @api.response(204, "Node variables deleted successfully")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App, node_id: str): def delete(self, app_model: App, node_id: str):
validate_node_id(node_id) validate_node_id(node_id)
@ -303,13 +266,13 @@ class VariableApi(Resource):
_PATCH_NAME_FIELD = "name" _PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value" _PATCH_VALUE_FIELD = "value"
@console_ns.doc("get_variable") @api.doc("get_variable")
@console_ns.doc(description="Get a specific workflow variable") @api.doc(description="Get a specific workflow variable")
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model) @api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.response(404, "Variable not found") @api.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str): def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
session=db.session(), session=db.session(),
@ -321,10 +284,10 @@ class VariableApi(Resource):
raise NotFoundError(description=f"variable not found, id={variable_id}") raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable return variable
@console_ns.doc("update_variable") @api.doc("update_variable")
@console_ns.doc(description="Update a workflow variable") @api.doc(description="Update a workflow variable")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"UpdateVariableRequest", "UpdateVariableRequest",
{ {
"name": fields.String(description="Variable name"), "name": fields.String(description="Variable name"),
@ -332,10 +295,10 @@ class VariableApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model) @api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.response(404, "Variable not found") @api.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str): def patch(self, app_model: App, variable_id: str):
# Request payload for file types: # Request payload for file types:
# #
@ -397,10 +360,10 @@ class VariableApi(Resource):
db.session.commit() db.session.commit()
return variable return variable
@console_ns.doc("delete_variable") @api.doc("delete_variable")
@console_ns.doc(description="Delete a workflow variable") @api.doc(description="Delete a workflow variable")
@console_ns.response(204, "Variable deleted successfully") @api.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found") @api.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
def delete(self, app_model: App, variable_id: str): def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -418,12 +381,12 @@ class VariableApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class VariableResetApi(Resource): class VariableResetApi(Resource):
@console_ns.doc("reset_variable") @api.doc("reset_variable")
@console_ns.doc(description="Reset a workflow variable to its default value") @api.doc(description="Reset a workflow variable to its default value")
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"}) @api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model) @api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
@console_ns.response(204, "Variable reset (no content)") @api.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found") @api.response(404, "Variable not found")
@_api_prerequisite @_api_prerequisite
def put(self, app_model: App, variable_id: str): def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService( draft_var_srv = WorkflowDraftVariableService(
@ -447,7 +410,7 @@ class VariableResetApi(Resource):
if resetted is None: if resetted is None:
return Response("", 204) return Response("", 204)
else: else:
return marshal(resetted, workflow_draft_variable_model) return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@ -466,13 +429,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
class ConversationVariableCollectionApi(Resource): class ConversationVariableCollectionApi(Resource):
@console_ns.doc("get_conversation_variables") @api.doc("get_conversation_variables")
@console_ns.doc(description="Get conversation variables for workflow") @api.doc(description="Get conversation variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model) @api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@console_ns.response(404, "Draft workflow not found") @api.response(404, "Draft workflow not found")
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_list_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller. # so their IDs can be returned to the caller.
@ -488,23 +451,23 @@ class ConversationVariableCollectionApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource): class SystemVariableCollectionApi(Resource):
@console_ns.doc("get_system_variables") @api.doc("get_system_variables")
@console_ns.doc(description="Get system variables for workflow") @api.doc(description="Get system variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model) @api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
@_api_prerequisite @_api_prerequisite
@marshal_with(workflow_draft_variable_list_model) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID) return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables") @console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
class EnvironmentVariableCollectionApi(Resource): class EnvironmentVariableCollectionApi(Resource):
@console_ns.doc("get_environment_variables") @api.doc("get_environment_variables")
@console_ns.doc(description="Get environment variables for workflow") @api.doc(description="Get environment variables for workflow")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Environment variables retrieved successfully") @api.response(200, "Environment variables retrieved successfully")
@console_ns.response(404, "Draft workflow not found") @api.response(404, "Draft workflow not found")
@_api_prerequisite @_api_prerequisite
def get(self, app_model: App): def get(self, app_model: App):
""" """

View File

@ -1,21 +1,15 @@
from typing import Literal, cast from typing import cast
from flask import request from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, fields, marshal_with from flask_restx.inputs import int_range
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
from fields.workflow_run_fields import ( from fields.workflow_run_fields import (
advanced_chat_workflow_run_for_list_fields,
advanced_chat_workflow_run_pagination_fields, advanced_chat_workflow_run_pagination_fields,
workflow_run_count_fields, workflow_run_count_fields,
workflow_run_detail_fields, workflow_run_detail_fields,
workflow_run_for_list_fields,
workflow_run_node_execution_fields,
workflow_run_node_execution_list_fields, workflow_run_node_execution_list_fields,
workflow_run_pagination_fields, workflow_run_pagination_fields,
) )
@ -28,148 +22,96 @@ from services.workflow_run_service import WorkflowRunService
# Workflow run status choices for filtering # Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"] WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
# Base models def _parse_workflow_run_list_args():
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields) """
Parse common arguments for workflow run list endpoints.
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields) Returns:
Parsed arguments containing last_id, limit, status, and triggered_from filters
# Models that depend on simple_account_fields """
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy() parser = (
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( reqparse.RequestParser()
simple_account_model, attribute="created_by_account", allow_null=True .add_argument("last_id", type=uuid_value, location="args")
) .add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy) .add_argument(
"status",
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy() type=str,
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested( choices=WORKFLOW_RUN_STATUS_CHOICES,
simple_account_model, attribute="created_by_account", allow_null=True location="args",
) required=False,
advanced_chat_workflow_run_for_list_model = console_ns.model( )
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy .add_argument(
) "triggered_from",
type=str,
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy() choices=["debugging", "app-run"],
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested( location="args",
simple_account_model, attribute="created_by_account", allow_null=True required=False,
) help="Filter by trigger source: debugging or app-run",
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested( )
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
simple_account_model, attribute="created_by_account", allow_null=True
)
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
simple_end_user_model, attribute="created_by_end_user", allow_null=True
)
workflow_run_node_execution_model = console_ns.model(
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
)
# Simple models without nested dependencies
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
# Pagination models that depend on list models
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
)
advanced_chat_workflow_run_pagination_model = console_ns.model(
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
)
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
workflow_run_node_execution_list_model = console_ns.model(
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunListQuery(BaseModel):
last_id: str | None = Field(default=None, description="Last run ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
) )
triggered_from: Literal["debugging", "app-run"] | None = Field( return parser.parse_args()
default=None, description="Filter by trigger source: debugging or app-run"
def _parse_workflow_run_count_args():
"""
Parse common arguments for workflow run count endpoints.
Returns:
Parsed arguments containing status, time_range, and triggered_from filters
"""
parser = (
reqparse.RequestParser()
.add_argument(
"status",
type=str,
choices=WORKFLOW_RUN_STATUS_CHOICES,
location="args",
required=False,
)
.add_argument(
"time_range",
type=time_duration,
location="args",
required=False,
help="Time range filter (e.g., 7d, 4h, 30m, 30s)",
)
.add_argument(
"triggered_from",
type=str,
choices=["debugging", "app-run"],
location="args",
required=False,
help="Filter by trigger source: debugging or app-run",
)
) )
return parser.parse_args()
@field_validator("last_id")
@classmethod
def validate_last_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class WorkflowRunCountQuery(BaseModel):
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
default=None, description="Workflow run status filter"
)
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
triggered_from: Literal["debugging", "app-run"] | None = Field(
default=None, description="Filter by trigger source: debugging or app-run"
)
@field_validator("time_range")
@classmethod
def validate_time_range(cls, value: str | None) -> str | None:
if value is None:
return value
return time_duration(value)
console_ns.schema_model(
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
WorkflowRunCountQuery.__name__,
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
class AdvancedChatAppWorkflowRunListApi(Resource): class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs") @api.doc("get_advanced_chat_workflow_runs")
@console_ns.doc(description="Get advanced chat workflow run list") @api.doc(description="Get advanced chat workflow run list")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@console_ns.doc( @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
) @api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(advanced_chat_workflow_run_pagination_model) @marshal_with(advanced_chat_workflow_run_pagination_fields)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get advanced chat app workflow run list Get advanced chat app workflow run list
""" """
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = _parse_workflow_run_list_args()
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from) WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args_model.triggered_from if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -183,13 +125,11 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count") @console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
class AdvancedChatAppWorkflowRunCountApi(Resource): class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.doc("get_advanced_chat_workflow_runs_count") @api.doc("get_advanced_chat_workflow_runs_count")
@console_ns.doc(description="Get advanced chat workflow runs count statistics") @api.doc(description="Get advanced chat workflow runs count statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.doc( @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} @api.doc(
)
@console_ns.doc(
params={ params={
"time_range": ( "time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@ -197,27 +137,23 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
) )
} }
) )
@console_ns.doc( @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT]) @get_app_model(mode=[AppMode.ADVANCED_CHAT])
@marshal_with(workflow_run_count_model) @marshal_with(workflow_run_count_fields)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get advanced chat workflow runs count statistics Get advanced chat workflow runs count statistics
""" """
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = _parse_workflow_run_count_args()
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified # Default to DEBUGGING if not specified
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from) WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args_model.triggered_from if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -234,34 +170,28 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs") @console_ns.route("/apps/<uuid:app_id>/workflow-runs")
class WorkflowRunListApi(Resource): class WorkflowRunListApi(Resource):
@console_ns.doc("get_workflow_runs") @api.doc("get_workflow_runs")
@console_ns.doc(description="Get workflow run list") @api.doc(description="Get workflow run list")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"}) @api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
@console_ns.doc( @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
) @api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
@console_ns.doc(
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
)
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_pagination_model) @marshal_with(workflow_run_pagination_fields)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get workflow run list Get workflow run list
""" """
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = _parse_workflow_run_list_args()
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from) WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args_model.triggered_from if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -275,13 +205,11 @@ class WorkflowRunListApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
class WorkflowRunCountApi(Resource): class WorkflowRunCountApi(Resource):
@console_ns.doc("get_workflow_runs_count") @api.doc("get_workflow_runs_count")
@console_ns.doc(description="Get workflow runs count statistics") @api.doc(description="Get workflow runs count statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.doc( @api.doc(params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"})
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"} @api.doc(
)
@console_ns.doc(
params={ params={
"time_range": ( "time_range": (
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), " "Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
@ -289,27 +217,23 @@ class WorkflowRunCountApi(Resource):
) )
} }
) )
@console_ns.doc( @api.doc(params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"})
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"} @api.response(200, "Workflow runs count retrieved successfully", workflow_run_count_fields)
)
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_count_model) @marshal_with(workflow_run_count_fields)
def get(self, app_model: App): def get(self, app_model: App):
""" """
Get workflow runs count statistics Get workflow runs count statistics
""" """
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = _parse_workflow_run_count_args()
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility) # Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom(args_model.triggered_from) WorkflowRunTriggeredFrom(args.get("triggered_from"))
if args_model.triggered_from if args.get("triggered_from")
else WorkflowRunTriggeredFrom.DEBUGGING else WorkflowRunTriggeredFrom.DEBUGGING
) )
@ -326,16 +250,16 @@ class WorkflowRunCountApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
class WorkflowRunDetailApi(Resource): class WorkflowRunDetailApi(Resource):
@console_ns.doc("get_workflow_run_detail") @api.doc("get_workflow_run_detail")
@console_ns.doc(description="Get workflow run detail") @api.doc(description="Get workflow run detail")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model) @api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
@console_ns.response(404, "Workflow run not found") @api.response(404, "Workflow run not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_detail_model) @marshal_with(workflow_run_detail_fields)
def get(self, app_model: App, run_id): def get(self, app_model: App, run_id):
""" """
Get workflow run detail Get workflow run detail
@ -350,16 +274,16 @@ class WorkflowRunDetailApi(Resource):
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions") @console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
class WorkflowRunNodeExecutionListApi(Resource): class WorkflowRunNodeExecutionListApi(Resource):
@console_ns.doc("get_workflow_run_node_executions") @api.doc("get_workflow_run_node_executions")
@console_ns.doc(description="Get workflow run node execution list") @api.doc(description="Get workflow run node execution list")
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"}) @api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model) @api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
@console_ns.response(404, "Workflow run not found") @api.response(404, "Workflow run not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_list_model) @marshal_with(workflow_run_node_execution_list_fields)
def get(self, app_model: App, run_id): def get(self, app_model: App, run_id):
""" """
Get workflow run node execution list Get workflow run node execution list

View File

@ -1,38 +1,18 @@
from flask import abort, jsonify, request from flask import abort, jsonify
from flask_restx import Resource from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import parse_time_range from libs.datetime_utils import parse_time_range
from libs.helper import DatetimeString
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory from repositories.factory import DifyAPIRepositoryFactory
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowStatisticQuery(BaseModel):
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
@field_validator("start", "end", mode="before")
@classmethod
def blank_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
console_ns.schema_model(
WorkflowStatisticQuery.__name__,
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations") @console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
class WorkflowDailyRunsStatistic(Resource): class WorkflowDailyRunsStatistic(Resource):
@ -41,11 +21,11 @@ class WorkflowDailyRunsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_runs_statistic") @api.doc("get_workflow_daily_runs_statistic")
@console_ns.doc(description="Get workflow daily runs statistics") @api.doc(description="Get workflow daily runs statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@console_ns.response(200, "Daily runs statistics retrieved successfully") @api.response(200, "Daily runs statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -53,12 +33,17 @@ class WorkflowDailyRunsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone) start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -81,11 +66,11 @@ class WorkflowDailyTerminalsStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_terminals_statistic") @api.doc("get_workflow_daily_terminals_statistic")
@console_ns.doc(description="Get workflow daily terminals statistics") @api.doc(description="Get workflow daily terminals statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@console_ns.response(200, "Daily terminals statistics retrieved successfully") @api.response(200, "Daily terminals statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -93,12 +78,17 @@ class WorkflowDailyTerminalsStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone) start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -121,11 +111,11 @@ class WorkflowDailyTokenCostStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_daily_token_cost_statistic") @api.doc("get_workflow_daily_token_cost_statistic")
@console_ns.doc(description="Get workflow daily token cost statistics") @api.doc(description="Get workflow daily token cost statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@console_ns.response(200, "Daily token cost statistics retrieved successfully") @api.response(200, "Daily token cost statistics retrieved successfully")
@get_app_model @get_app_model
@setup_required @setup_required
@login_required @login_required
@ -133,12 +123,17 @@ class WorkflowDailyTokenCostStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone) start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))
@ -161,11 +156,11 @@ class WorkflowAverageAppInteractionStatistic(Resource):
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
@console_ns.doc("get_workflow_average_app_interaction_statistic") @api.doc("get_workflow_average_app_interaction_statistic")
@console_ns.doc(description="Get workflow average app interaction statistics") @api.doc(description="Get workflow average app interaction statistics")
@console_ns.doc(params={"app_id": "Application ID"}) @api.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__]) @api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
@console_ns.response(200, "Average app interaction statistics retrieved successfully") @api.response(200, "Average app interaction statistics retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -173,12 +168,17 @@ class WorkflowAverageAppInteractionStatistic(Resource):
def get(self, app_model): def get(self, app_model):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
)
args = parser.parse_args()
assert account.timezone is not None assert account.timezone is not None
try: try:
start_date, end_date = parse_time_range(args.start, args.end, account.timezone) start_date, end_date = parse_time_range(args["start"], args["end"], account.timezone)
except ValueError as e: except ValueError as e:
abort(400, description=str(e)) abort(400, description=str(e))

View File

@ -1,13 +1,14 @@
import logging import logging
from flask import request from flask_restx import Resource, marshal_with, reqparse
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -15,35 +16,12 @@ from models.enums import AppTriggerStatus
from models.model import Account, App, AppMode from models.model import Account, App, AppMode
from models.trigger import AppTrigger, WorkflowWebhookTrigger from models.trigger import AppTrigger, WorkflowWebhookTrigger
from .. import console_ns
from ..app.wraps import get_app_model
from ..wraps import account_initialization_required, edit_permission_required, setup_required
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class Parser(BaseModel):
node_id: str
class ParserEnable(BaseModel):
trigger_id: str
enable_trigger: bool
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
console_ns.schema_model(
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
class WebhookTriggerApi(Resource): class WebhookTriggerApi(Resource):
"""Webhook Trigger API""" """Webhook Trigger API"""
@console_ns.expect(console_ns.models[Parser.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -51,9 +29,11 @@ class WebhookTriggerApi(Resource):
@marshal_with(webhook_trigger_fields) @marshal_with(webhook_trigger_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Get webhook trigger for a node""" """Get webhook trigger for a node"""
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = reqparse.RequestParser()
parser.add_argument("node_id", type=str, required=True, help="Node ID is required")
args = parser.parse_args()
node_id = args.node_id node_id = str(args["node_id"])
with Session(db.engine) as session: with Session(db.engine) as session:
# Get webhook trigger for this app and node # Get webhook trigger for this app and node
@ -72,7 +52,6 @@ class WebhookTriggerApi(Resource):
return webhook_trigger return webhook_trigger
@console_ns.route("/apps/<uuid:app_id>/triggers")
class AppTriggersApi(Resource): class AppTriggersApi(Resource):
"""App Triggers list API""" """App Triggers list API"""
@ -112,22 +91,26 @@ class AppTriggersApi(Resource):
return {"data": triggers} return {"data": triggers}
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
class AppTriggerEnableApi(Resource): class AppTriggerEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.WORKFLOW) @get_app_model(mode=AppMode.WORKFLOW)
@marshal_with(trigger_fields) @marshal_with(trigger_fields)
def post(self, app_model: App): def post(self, app_model: App):
"""Update app trigger (enable/disable)""" """Update app trigger (enable/disable)"""
args = ParserEnable.model_validate(console_ns.payload) parser = reqparse.RequestParser()
parser.add_argument("trigger_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("enable_trigger", type=bool, required=True, nullable=False, location="json")
args = parser.parse_args()
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None assert current_user.current_tenant_id is not None
if not current_user.has_edit_permission:
raise Forbidden()
trigger_id = args["trigger_id"]
trigger_id = args.trigger_id
with Session(db.engine) as session: with Session(db.engine) as session:
# Find the trigger using select # Find the trigger using select
trigger = session.execute( trigger = session.execute(
@ -142,7 +125,7 @@ class AppTriggerEnableApi(Resource):
raise NotFound("Trigger not found") raise NotFound("Trigger not found")
# Update status based on enable_trigger boolean # Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED trigger.status = AppTriggerStatus.ENABLED if args["enable_trigger"] else AppTriggerStatus.DISABLED
session.commit() session.commit()
session.refresh(trigger) session.refresh(trigger)
@ -155,3 +138,8 @@ class AppTriggerEnableApi(Resource):
trigger.icon = "" # type: ignore trigger.icon = "" # type: ignore
return trigger return trigger
api.add_resource(WebhookTriggerApi, "/apps/<uuid:app_id>/workflows/triggers/webhook")
api.add_resource(AppTriggersApi, "/apps/<uuid:app_id>/triggers")
api.add_resource(AppTriggerEnableApi, "/apps/<uuid:app_id>/trigger-enable")

View File

@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -20,13 +20,13 @@ active_check_parser = (
@console_ns.route("/activate/check") @console_ns.route("/activate/check")
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@console_ns.doc("check_activation_token") @api.doc("check_activation_token")
@console_ns.doc(description="Check if activation token is valid") @api.doc(description="Check if activation token is valid")
@console_ns.expect(active_check_parser) @api.expect(active_check_parser)
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model(
"ActivationCheckResponse", "ActivationCheckResponse",
{ {
"is_valid": fields.Boolean(description="Whether token is valid"), "is_valid": fields.Boolean(description="Whether token is valid"),
@ -69,13 +69,13 @@ active_parser = (
@console_ns.route("/activate") @console_ns.route("/activate")
class ActivateApi(Resource): class ActivateApi(Resource):
@console_ns.doc("activate_account") @api.doc("activate_account")
@console_ns.doc(description="Activate account with invitation token") @api.doc(description="Activate account with invitation token")
@console_ns.expect(active_parser) @api.expect(active_parser)
@console_ns.response( @api.response(
200, 200,
"Account activated successfully", "Account activated successfully",
console_ns.model( api.model(
"ActivationResponse", "ActivationResponse",
{ {
"result": fields.String(description="Operation result"), "result": fields.String(description="Operation result"),
@ -83,7 +83,7 @@ class ActivateApi(Resource):
}, },
), ),
) )
@console_ns.response(400, "Already activated or invalid token") @api.response(400, "Already activated or invalid token")
def post(self): def post(self):
args = active_parser.parse_args() args = active_parser.parse_args()

View File

@ -1,8 +1,8 @@
from flask_restx import Resource, reqparse from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.auth.error import ApiKeyAuthFailedError from controllers.console.auth.error import ApiKeyAuthFailedError
from controllers.console.wraps import is_admin_or_owner_required
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@ -39,10 +39,12 @@ class ApiKeyAuthDataSourceBinding(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required
def post(self): def post(self):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("category", type=str, required=True, nullable=False, location="json") .add_argument("category", type=str, required=True, nullable=False, location="json")
@ -63,10 +65,12 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_admin_or_owner_required
def delete(self, binding_id): def delete(self, binding_id):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)

View File

@ -3,11 +3,11 @@ import logging
import httpx import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource, fields from flask_restx import Resource, fields
from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import is_admin_or_owner_required from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required from ..wraps import account_initialization_required, setup_required
@ -29,22 +29,24 @@ def get_oauth_providers():
@console_ns.route("/oauth/data-source/<string:provider>") @console_ns.route("/oauth/data-source/<string:provider>")
class OAuthDataSource(Resource): class OAuthDataSource(Resource):
@console_ns.doc("oauth_data_source") @api.doc("oauth_data_source")
@console_ns.doc(description="Get OAuth authorization URL for data source provider") @api.doc(description="Get OAuth authorization URL for data source provider")
@console_ns.doc(params={"provider": "Data source provider name (notion)"}) @api.doc(params={"provider": "Data source provider name (notion)"})
@console_ns.response( @api.response(
200, 200,
"Authorization URL or internal setup success", "Authorization URL or internal setup success",
console_ns.model( api.model(
"OAuthDataSourceResponse", "OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")}, {"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
), ),
) )
@console_ns.response(400, "Invalid provider") @api.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@is_admin_or_owner_required
def get(self, provider: str): def get(self, provider: str):
# The role of the current user in the table must be admin or owner # The role of the current user in the table must be admin or owner
current_user, _ = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
@ -63,17 +65,17 @@ class OAuthDataSource(Resource):
@console_ns.route("/oauth/data-source/callback/<string:provider>") @console_ns.route("/oauth/data-source/callback/<string:provider>")
class OAuthDataSourceCallback(Resource): class OAuthDataSourceCallback(Resource):
@console_ns.doc("oauth_data_source_callback") @api.doc("oauth_data_source_callback")
@console_ns.doc(description="Handle OAuth callback from data source provider") @api.doc(description="Handle OAuth callback from data source provider")
@console_ns.doc( @api.doc(
params={ params={
"provider": "Data source provider name (notion)", "provider": "Data source provider name (notion)",
"code": "Authorization code from OAuth provider", "code": "Authorization code from OAuth provider",
"error": "Error message from OAuth provider", "error": "Error message from OAuth provider",
} }
) )
@console_ns.response(302, "Redirect to console with result") @api.response(302, "Redirect to console with result")
@console_ns.response(400, "Invalid provider") @api.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -94,17 +96,17 @@ class OAuthDataSourceCallback(Resource):
@console_ns.route("/oauth/data-source/binding/<string:provider>") @console_ns.route("/oauth/data-source/binding/<string:provider>")
class OAuthDataSourceBinding(Resource): class OAuthDataSourceBinding(Resource):
@console_ns.doc("oauth_data_source_binding") @api.doc("oauth_data_source_binding")
@console_ns.doc(description="Bind OAuth data source with authorization code") @api.doc(description="Bind OAuth data source with authorization code")
@console_ns.doc( @api.doc(
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"} params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
) )
@console_ns.response( @api.response(
200, 200,
"Data source binding success", "Data source binding success",
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}), api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
) )
@console_ns.response(400, "Invalid provider or code") @api.response(400, "Invalid provider or code")
def get(self, provider: str): def get(self, provider: str):
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
@ -128,15 +130,15 @@ class OAuthDataSourceBinding(Resource):
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") @console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
class OAuthDataSourceSync(Resource): class OAuthDataSourceSync(Resource):
@console_ns.doc("oauth_data_source_sync") @api.doc("oauth_data_source_sync")
@console_ns.doc(description="Sync data from OAuth data source") @api.doc(description="Sync data from OAuth data source")
@console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"}) @api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
@console_ns.response( @api.response(
200, 200,
"Data source sync success", "Data source sync success",
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}), api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
) )
@console_ns.response(400, "Invalid provider or sync failed") @api.response(400, "Invalid provider or sync failed")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -6,7 +6,7 @@ from flask_restx import Resource, fields, reqparse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError, EmailCodeError,
EmailPasswordResetLimitError, EmailPasswordResetLimitError,
@ -27,10 +27,10 @@ from services.feature_service import FeatureService
@console_ns.route("/forgot-password") @console_ns.route("/forgot-password")
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@console_ns.doc("send_forgot_password_email") @api.doc("send_forgot_password_email")
@console_ns.doc(description="Send password reset email") @api.doc(description="Send password reset email")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"ForgotPasswordEmailRequest", "ForgotPasswordEmailRequest",
{ {
"email": fields.String(required=True, description="Email address"), "email": fields.String(required=True, description="Email address"),
@ -38,10 +38,10 @@ class ForgotPasswordSendEmailApi(Resource):
}, },
) )
) )
@console_ns.response( @api.response(
200, 200,
"Email sent successfully", "Email sent successfully",
console_ns.model( api.model(
"ForgotPasswordEmailResponse", "ForgotPasswordEmailResponse",
{ {
"result": fields.String(description="Operation result"), "result": fields.String(description="Operation result"),
@ -50,7 +50,7 @@ class ForgotPasswordSendEmailApi(Resource):
}, },
), ),
) )
@console_ns.response(400, "Invalid email or rate limit exceeded") @api.response(400, "Invalid email or rate limit exceeded")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -85,10 +85,10 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.route("/forgot-password/validity") @console_ns.route("/forgot-password/validity")
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@console_ns.doc("check_forgot_password_code") @api.doc("check_forgot_password_code")
@console_ns.doc(description="Verify password reset code") @api.doc(description="Verify password reset code")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"ForgotPasswordCheckRequest", "ForgotPasswordCheckRequest",
{ {
"email": fields.String(required=True, description="Email address"), "email": fields.String(required=True, description="Email address"),
@ -97,10 +97,10 @@ class ForgotPasswordCheckApi(Resource):
}, },
) )
) )
@console_ns.response( @api.response(
200, 200,
"Code verified successfully", "Code verified successfully",
console_ns.model( api.model(
"ForgotPasswordCheckResponse", "ForgotPasswordCheckResponse",
{ {
"is_valid": fields.Boolean(description="Whether code is valid"), "is_valid": fields.Boolean(description="Whether code is valid"),
@ -109,7 +109,7 @@ class ForgotPasswordCheckApi(Resource):
}, },
), ),
) )
@console_ns.response(400, "Invalid code or token") @api.response(400, "Invalid code or token")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):
@ -152,10 +152,10 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.route("/forgot-password/resets") @console_ns.route("/forgot-password/resets")
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@console_ns.doc("reset_password") @api.doc("reset_password")
@console_ns.doc(description="Reset password with verification token") @api.doc(description="Reset password with verification token")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"ForgotPasswordResetRequest", "ForgotPasswordResetRequest",
{ {
"token": fields.String(required=True, description="Verification token"), "token": fields.String(required=True, description="Verification token"),
@ -164,12 +164,12 @@ class ForgotPasswordResetApi(Resource):
}, },
) )
) )
@console_ns.response( @api.response(
200, 200,
"Password reset successfully", "Password reset successfully",
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}), api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
) )
@console_ns.response(400, "Invalid token or password mismatch") @api.response(400, "Invalid token or password mismatch")
@setup_required @setup_required
@email_password_login_enabled @email_password_login_enabled
def post(self): def post(self):

View File

@ -26,7 +26,7 @@ from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService from services.feature_service import FeatureService
from .. import console_ns from .. import api, console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,13 +56,11 @@ def get_oauth_providers():
@console_ns.route("/oauth/login/<provider>") @console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource): class OAuthLogin(Resource):
@console_ns.doc("oauth_login") @api.doc("oauth_login")
@console_ns.doc(description="Initiate OAuth login process") @api.doc(description="Initiate OAuth login process")
@console_ns.doc( @api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"} @api.response(302, "Redirect to OAuth authorization URL")
) @api.response(400, "Invalid provider")
@console_ns.response(302, "Redirect to OAuth authorization URL")
@console_ns.response(400, "Invalid provider")
def get(self, provider: str): def get(self, provider: str):
invite_token = request.args.get("invite_token") or None invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
@ -77,17 +75,17 @@ class OAuthLogin(Resource):
@console_ns.route("/oauth/authorize/<provider>") @console_ns.route("/oauth/authorize/<provider>")
class OAuthCallback(Resource): class OAuthCallback(Resource):
@console_ns.doc("oauth_callback") @api.doc("oauth_callback")
@console_ns.doc(description="Handle OAuth callback and complete login process") @api.doc(description="Handle OAuth callback and complete login process")
@console_ns.doc( @api.doc(
params={ params={
"provider": "OAuth provider name (github/google)", "provider": "OAuth provider name (github/google)",
"code": "Authorization code from OAuth provider", "code": "Authorization code from OAuth provider",
"state": "Optional state parameter (used for invite token)", "state": "Optional state parameter (used for invite token)",
} }
) )
@console_ns.response(302, "Redirect to console with access token") @api.response(302, "Redirect to console with access token")
@console_ns.response(400, "OAuth process failed") @api.response(400, "OAuth process failed")
def get(self, provider: str): def get(self, provider: str):
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():

View File

@ -1,7 +1,4 @@
import base64 from flask_restx import Resource, reqparse
from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
@ -44,37 +41,3 @@ class Invoices(Resource):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
BillingService.is_tenant_owner_or_admin(current_user) BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_invoices(current_user.email, current_tenant_id) return BillingService.get_invoices(current_user.email, current_tenant_id)
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
class PartnerTenants(Resource):
@console_ns.doc("sync_partner_tenants_bindings")
@console_ns.doc(description="Sync partner tenants bindings")
@console_ns.doc(params={"partner_key": "Partner key"})
@console_ns.expect(
console_ns.model(
"SyncPartnerTenantsBindingsRequest",
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
)
)
@console_ns.response(200, "Tenants synced to partner successfully")
@console_ns.response(400, "Invalid partner information")
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def put(self, partner_key: str):
current_user, _ = current_account_with_tenant()
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
args = parser.parse_args()
try:
click_id = args["click_id"]
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
except Exception:
raise BadRequest("Invalid partner_key")
if not click_id or not decoded_partner_key or not current_user.id:
raise BadRequest("Invalid partner information")
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)

View File

@ -7,18 +7,14 @@ from werkzeug.exceptions import Forbidden, NotFound
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.apikey import ( from controllers.console.apikey import api_key_fields, api_key_list
api_key_item_model,
api_key_list_model,
)
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check, cloud_edition_billing_rate_limit_check,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -30,22 +26,8 @@ from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.app_fields import related_app_list
from fields.dataset_fields import ( from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from fields.document_fields import document_status_fields from fields.document_fields import document_status_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length from libs.validators import validate_description_length
@ -55,58 +37,6 @@ from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = _get_or_create_model("DatasetBase", dataset_fields)
tag_model = _get_or_create_model("Tag", tag_fields)
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
dataset_query_detail_model = _get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields)
app_detail_kernel_model = _get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = _get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_name(name: str) -> str: def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40: if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.") raise ValueError("Name must be between 1 to 40 characters.")
@ -188,9 +118,9 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
@console_ns.route("/datasets") @console_ns.route("/datasets")
class DatasetListApi(Resource): class DatasetListApi(Resource):
@console_ns.doc("get_datasets") @api.doc("get_datasets")
@console_ns.doc(description="Get list of datasets") @api.doc(description="Get list of datasets")
@console_ns.doc( @api.doc(
params={ params={
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)", "limit": "Number of items per page (default: 20)",
@ -200,7 +130,7 @@ class DatasetListApi(Resource):
"include_all": "Include all datasets (default: false)", "include_all": "Include all datasets (default: false)",
} }
) )
@console_ns.response(200, "Datasets retrieved successfully") @api.response(200, "Datasets retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -253,10 +183,10 @@ class DatasetListApi(Resource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@console_ns.doc("create_dataset") @api.doc("create_dataset")
@console_ns.doc(description="Create a new dataset") @api.doc(description="Create a new dataset")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"CreateDatasetRequest", "CreateDatasetRequest",
{ {
"name": fields.String(required=True, description="Dataset name (1-40 characters)"), "name": fields.String(required=True, description="Dataset name (1-40 characters)"),
@ -269,8 +199,8 @@ class DatasetListApi(Resource):
}, },
) )
) )
@console_ns.response(201, "Dataset created successfully") @api.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters") @api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -348,12 +278,12 @@ class DatasetListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>") @console_ns.route("/datasets/<uuid:dataset_id>")
class DatasetApi(Resource): class DatasetApi(Resource):
@console_ns.doc("get_dataset") @api.doc("get_dataset")
@console_ns.doc(description="Get dataset details") @api.doc(description="Get dataset details")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model) @api.response(200, "Dataset retrieved successfully", dataset_detail_fields)
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -397,10 +327,10 @@ class DatasetApi(Resource):
return data, 200 return data, 200
@console_ns.doc("update_dataset") @api.doc("update_dataset")
@console_ns.doc(description="Update dataset details") @api.doc(description="Update dataset details")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"UpdateDatasetRequest", "UpdateDatasetRequest",
{ {
"name": fields.String(description="Dataset name"), "name": fields.String(description="Dataset name"),
@ -411,9 +341,9 @@ class DatasetApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model) @api.response(200, "Dataset updated successfully", dataset_detail_fields)
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -557,10 +487,10 @@ class DatasetApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/use-check") @console_ns.route("/datasets/<uuid:dataset_id>/use-check")
class DatasetUseCheckApi(Resource): class DatasetUseCheckApi(Resource):
@console_ns.doc("check_dataset_use") @api.doc("check_dataset_use")
@console_ns.doc(description="Check if dataset is in use") @api.doc(description="Check if dataset is in use")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Dataset use status retrieved successfully") @api.response(200, "Dataset use status retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -573,10 +503,10 @@ class DatasetUseCheckApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/queries") @console_ns.route("/datasets/<uuid:dataset_id>/queries")
class DatasetQueryApi(Resource): class DatasetQueryApi(Resource):
@console_ns.doc("get_dataset_queries") @api.doc("get_dataset_queries")
@console_ns.doc(description="Get dataset query history") @api.doc(description="Get dataset query history")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model) @api.response(200, "Query history retrieved successfully", dataset_query_detail_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -598,7 +528,7 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = { response = {
"data": marshal(dataset_queries, dataset_query_detail_model), "data": marshal(dataset_queries, dataset_query_detail_fields),
"has_more": len(dataset_queries) == limit, "has_more": len(dataset_queries) == limit,
"limit": limit, "limit": limit,
"total": total, "total": total,
@ -609,9 +539,9 @@ class DatasetQueryApi(Resource):
@console_ns.route("/datasets/indexing-estimate") @console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource): class DatasetIndexingEstimateApi(Resource):
@console_ns.doc("estimate_dataset_indexing") @api.doc("estimate_dataset_indexing")
@console_ns.doc(description="Estimate dataset indexing cost") @api.doc(description="Estimate dataset indexing cost")
@console_ns.response(200, "Indexing estimate calculated successfully") @api.response(200, "Indexing estimate calculated successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -719,14 +649,14 @@ class DatasetIndexingEstimateApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/related-apps") @console_ns.route("/datasets/<uuid:dataset_id>/related-apps")
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
@console_ns.doc("get_dataset_related_apps") @api.doc("get_dataset_related_apps")
@console_ns.doc(description="Get applications related to dataset") @api.doc(description="Get applications related to dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model) @api.response(200, "Related apps retrieved successfully", related_app_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(related_app_list_model) @marshal_with(related_app_list)
def get(self, dataset_id): def get(self, dataset_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
@ -752,10 +682,10 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status") @console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
class DatasetIndexingStatusApi(Resource): class DatasetIndexingStatusApi(Resource):
@console_ns.doc("get_dataset_indexing_status") @api.doc("get_dataset_indexing_status")
@console_ns.doc(description="Get dataset indexing status") @api.doc(description="Get dataset indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Indexing status retrieved successfully") @api.response(200, "Indexing status retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -807,13 +737,13 @@ class DatasetApiKeyApi(Resource):
token_prefix = "dataset-" token_prefix = "dataset-"
resource_type = "dataset" resource_type = "dataset"
@console_ns.doc("get_dataset_api_keys") @api.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys") @api.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model) @api.response(200, "API keys retrieved successfully", api_key_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_list_model) @marshal_with(api_key_list)
def get(self): def get(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars( keys = db.session.scalars(
@ -823,11 +753,13 @@ class DatasetApiKeyApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
@marshal_with(api_key_item_model) @marshal_with(api_key_fields)
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() # The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
@ -836,7 +768,7 @@ class DatasetApiKeyApi(Resource):
) )
if current_key_count >= self.max_keys: if current_key_count >= self.max_keys:
console_ns.abort( api.abort(
400, 400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.", message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded", code="max_keys_exceeded",
@ -856,17 +788,21 @@ class DatasetApiKeyApi(Resource):
class DatasetApiDeleteApi(Resource): class DatasetApiDeleteApi(Resource):
resource_type = "dataset" resource_type = "dataset"
@console_ns.doc("delete_dataset_api_key") @api.doc("delete_dataset_api_key")
@console_ns.doc(description="Delete dataset API key") @api.doc(description="Delete dataset API key")
@console_ns.doc(params={"api_key_id": "API key ID"}) @api.doc(params={"api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully") @api.response(204, "API key deleted successfully")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, api_key_id): def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .where(
@ -878,7 +814,7 @@ class DatasetApiDeleteApi(Resource):
) )
if key is None: if key is None:
console_ns.abort(404, message="API key not found") api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
@ -901,9 +837,9 @@ class DatasetEnableApiApi(Resource):
@console_ns.route("/datasets/api-base-info") @console_ns.route("/datasets/api-base-info")
class DatasetApiBaseUrlApi(Resource): class DatasetApiBaseUrlApi(Resource):
@console_ns.doc("get_dataset_api_base_info") @api.doc("get_dataset_api_base_info")
@console_ns.doc(description="Get dataset API base information") @api.doc(description="Get dataset API base information")
@console_ns.response(200, "API base info retrieved successfully") @api.response(200, "API base info retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -913,9 +849,9 @@ class DatasetApiBaseUrlApi(Resource):
@console_ns.route("/datasets/retrieval-setting") @console_ns.route("/datasets/retrieval-setting")
class DatasetRetrievalSettingApi(Resource): class DatasetRetrievalSettingApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting") @api.doc("get_dataset_retrieval_setting")
@console_ns.doc(description="Get dataset retrieval settings") @api.doc(description="Get dataset retrieval settings")
@console_ns.response(200, "Retrieval settings retrieved successfully") @api.response(200, "Retrieval settings retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -926,10 +862,10 @@ class DatasetRetrievalSettingApi(Resource):
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>") @console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
class DatasetRetrievalSettingMockApi(Resource): class DatasetRetrievalSettingMockApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting_mock") @api.doc("get_dataset_retrieval_setting_mock")
@console_ns.doc(description="Get mock dataset retrieval settings by vector type") @api.doc(description="Get mock dataset retrieval settings by vector type")
@console_ns.doc(params={"vector_type": "Vector store type"}) @api.doc(params={"vector_type": "Vector store type"})
@console_ns.response(200, "Mock retrieval settings retrieved successfully") @api.response(200, "Mock retrieval settings retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -939,11 +875,11 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs") @console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
class DatasetErrorDocs(Resource): class DatasetErrorDocs(Resource):
@console_ns.doc("get_dataset_error_docs") @api.doc("get_dataset_error_docs")
@console_ns.doc(description="Get dataset error documents") @api.doc(description="Get dataset error documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Error documents retrieved successfully") @api.response(200, "Error documents retrieved successfully")
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -959,12 +895,12 @@ class DatasetErrorDocs(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users") @console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
class DatasetPermissionUserListApi(Resource): class DatasetPermissionUserListApi(Resource):
@console_ns.doc("get_dataset_permission_users") @api.doc("get_dataset_permission_users")
@console_ns.doc(description="Get dataset permission user list") @api.doc(description="Get dataset permission user list")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Permission users retrieved successfully") @api.response(200, "Permission users retrieved successfully")
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -988,11 +924,11 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs") @console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
class DatasetAutoDisableLogApi(Resource): class DatasetAutoDisableLogApi(Resource):
@console_ns.doc("get_dataset_auto_disable_logs") @api.doc("get_dataset_auto_disable_logs")
@console_ns.doc(description="Get dataset auto disable logs") @api.doc(description="Get dataset auto disable logs")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Auto disable logs retrieved successfully") @api.response(200, "Auto disable logs retrieved successfully")
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -11,7 +11,7 @@ from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError, ProviderNotInitializeError,
@ -45,11 +45,9 @@ from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db from extensions.ext_database import db
from fields.dataset_fields import dataset_fields
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields, dataset_and_document_fields,
document_fields, document_fields,
document_metadata_fields,
document_status_fields, document_status_fields,
document_with_segments_fields, document_with_segments_fields,
) )
@ -63,36 +61,6 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = _get_or_create_model("Dataset", dataset_fields)
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
document_fields_copy = document_fields.copy()
document_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_model = _get_or_create_model("Document", document_fields_copy)
document_with_segments_fields_copy = document_with_segments_fields.copy()
document_with_segments_fields_copy["doc_metadata"] = fields.List(
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
)
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
class DocumentResource(Resource): class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document: def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
@ -136,10 +104,10 @@ class DocumentResource(Resource):
@console_ns.route("/datasets/process-rule") @console_ns.route("/datasets/process-rule")
class GetProcessRuleApi(Resource): class GetProcessRuleApi(Resource):
@console_ns.doc("get_process_rule") @api.doc("get_process_rule")
@console_ns.doc(description="Get dataset document processing rules") @api.doc(description="Get dataset document processing rules")
@console_ns.doc(params={"document_id": "Document ID (optional)"}) @api.doc(params={"document_id": "Document ID (optional)"})
@console_ns.response(200, "Process rules retrieved successfully") @api.response(200, "Process rules retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -184,9 +152,9 @@ class GetProcessRuleApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents") @console_ns.route("/datasets/<uuid:dataset_id>/documents")
class DatasetDocumentListApi(Resource): class DatasetDocumentListApi(Resource):
@console_ns.doc("get_dataset_documents") @api.doc("get_dataset_documents")
@console_ns.doc(description="Get documents in a dataset") @api.doc(description="Get documents in a dataset")
@console_ns.doc( @api.doc(
params={ params={
"dataset_id": "Dataset ID", "dataset_id": "Dataset ID",
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
@ -194,20 +162,19 @@ class DatasetDocumentListApi(Resource):
"keyword": "Search keyword", "keyword": "Search keyword",
"sort": "Sort order (default: -created_at)", "sort": "Sort order (default: -created_at)",
"fetch": "Fetch full details (default: false)", "fetch": "Fetch full details (default: false)",
"status": "Filter documents by display status",
} }
) )
@console_ns.response(200, "Documents retrieved successfully") @api.response(200, "Documents retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, dataset_id: str): def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str) sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
# "yes", "true", "t", "y", "1" convert to True, while others convert to False. # "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try: try:
fetch_val = request.args.get("fetch", default="false") fetch_val = request.args.get("fetch", default="false")
@ -236,9 +203,6 @@ class DatasetDocumentListApi(Resource):
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.where(Document.name.like(search))
@ -307,7 +271,7 @@ class DatasetDocumentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(dataset_and_document_model) @marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
@ -388,10 +352,10 @@ class DatasetDocumentListApi(Resource):
@console_ns.route("/datasets/init") @console_ns.route("/datasets/init")
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@console_ns.doc("init_dataset") @api.doc("init_dataset")
@console_ns.doc(description="Initialize dataset with documents") @api.doc(description="Initialize dataset with documents")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"DatasetInitRequest", "DatasetInitRequest",
{ {
"upload_file_id": fields.String(required=True, description="Upload file ID"), "upload_file_id": fields.String(required=True, description="Upload file ID"),
@ -401,12 +365,12 @@ class DatasetInitApi(Resource):
}, },
) )
) )
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model) @api.response(201, "Dataset initialized successfully", dataset_and_document_fields)
@console_ns.response(400, "Invalid request parameters") @api.response(400, "Invalid request parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(dataset_and_document_model) @marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
@ -477,12 +441,12 @@ class DatasetInitApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
class DocumentIndexingEstimateApi(DocumentResource): class DocumentIndexingEstimateApi(DocumentResource):
@console_ns.doc("estimate_document_indexing") @api.doc("estimate_document_indexing")
@console_ns.doc(description="Estimate document indexing cost") @api.doc(description="Estimate document indexing cost")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Indexing estimate calculated successfully") @api.response(200, "Indexing estimate calculated successfully")
@console_ns.response(404, "Document not found") @api.response(404, "Document not found")
@console_ns.response(400, "Document already finished") @api.response(400, "Document already finished")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -692,11 +656,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
class DocumentIndexingStatusApi(DocumentResource): class DocumentIndexingStatusApi(DocumentResource):
@console_ns.doc("get_document_indexing_status") @api.doc("get_document_indexing_status")
@console_ns.doc(description="Get document indexing status") @api.doc(description="Get document indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Indexing status retrieved successfully") @api.response(200, "Indexing status retrieved successfully")
@console_ns.response(404, "Document not found") @api.response(404, "Document not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -742,17 +706,17 @@ class DocumentIndexingStatusApi(DocumentResource):
class DocumentApi(DocumentResource): class DocumentApi(DocumentResource):
METADATA_CHOICES = {"all", "only", "without"} METADATA_CHOICES = {"all", "only", "without"}
@console_ns.doc("get_document") @api.doc("get_document")
@console_ns.doc(description="Get document details") @api.doc(description="Get document details")
@console_ns.doc( @api.doc(
params={ params={
"dataset_id": "Dataset ID", "dataset_id": "Dataset ID",
"document_id": "Document ID", "document_id": "Document ID",
"metadata": "Metadata inclusion (all/only/without)", "metadata": "Metadata inclusion (all/only/without)",
} }
) )
@console_ns.response(200, "Document retrieved successfully") @api.response(200, "Document retrieved successfully")
@console_ns.response(404, "Document not found") @api.response(404, "Document not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -863,14 +827,14 @@ class DocumentApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
class DocumentProcessingApi(DocumentResource): class DocumentProcessingApi(DocumentResource):
@console_ns.doc("update_document_processing") @api.doc("update_document_processing")
@console_ns.doc(description="Update document processing status (pause/resume)") @api.doc(description="Update document processing status (pause/resume)")
@console_ns.doc( @api.doc(
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"} params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
) )
@console_ns.response(200, "Processing status updated successfully") @api.response(200, "Processing status updated successfully")
@console_ns.response(404, "Document not found") @api.response(404, "Document not found")
@console_ns.response(400, "Invalid action") @api.response(400, "Invalid action")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -908,11 +872,11 @@ class DocumentProcessingApi(DocumentResource):
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") @console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
class DocumentMetadataApi(DocumentResource): class DocumentMetadataApi(DocumentResource):
@console_ns.doc("update_document_metadata") @api.doc("update_document_metadata")
@console_ns.doc(description="Update document metadata") @api.doc(description="Update document metadata")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"UpdateDocumentMetadataRequest", "UpdateDocumentMetadataRequest",
{ {
"doc_type": fields.String(description="Document type"), "doc_type": fields.String(description="Document type"),
@ -920,9 +884,9 @@ class DocumentMetadataApi(DocumentResource):
}, },
) )
) )
@console_ns.response(200, "Document metadata updated successfully") @api.response(200, "Document metadata updated successfully")
@console_ns.response(404, "Document not found") @api.response(404, "Document not found")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,22 +3,10 @@ from flask_restx import Resource, fields, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import ( from fields.dataset_fields import dataset_detail_fields
dataset_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
@ -26,51 +14,6 @@ from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService from services.knowledge_service import ExternalDatasetTestService
def _get_or_create_model(model_name: str, field_def):
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
def _build_dataset_detail_model():
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
tag_model = _get_or_create_model("Tag", tag_fields)
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
try:
dataset_detail_model = console_ns.models["DatasetDetail"]
except KeyError:
dataset_detail_model = _build_dataset_detail_model()
def _validate_name(name: str) -> str: def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 100: if not name or len(name) < 1 or len(name) > 100:
raise ValueError("Name must be between 1 to 100 characters.") raise ValueError("Name must be between 1 to 100 characters.")
@ -79,16 +22,16 @@ def _validate_name(name: str) -> str:
@console_ns.route("/datasets/external-knowledge-api") @console_ns.route("/datasets/external-knowledge-api")
class ExternalApiTemplateListApi(Resource): class ExternalApiTemplateListApi(Resource):
@console_ns.doc("get_external_api_templates") @api.doc("get_external_api_templates")
@console_ns.doc(description="Get external knowledge API templates") @api.doc(description="Get external knowledge API templates")
@console_ns.doc( @api.doc(
params={ params={
"page": "Page number (default: 1)", "page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)", "limit": "Number of items per page (default: 20)",
"keyword": "Search keyword", "keyword": "Search keyword",
} }
) )
@console_ns.response(200, "External API templates retrieved successfully") @api.response(200, "External API templates retrieved successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -152,11 +95,11 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>") @console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
class ExternalApiTemplateApi(Resource): class ExternalApiTemplateApi(Resource):
@console_ns.doc("get_external_api_template") @api.doc("get_external_api_template")
@console_ns.doc(description="Get external knowledge API template details") @api.doc(description="Get external knowledge API template details")
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
@console_ns.response(200, "External API template retrieved successfully") @api.response(200, "External API template retrieved successfully")
@console_ns.response(404, "Template not found") @api.response(404, "Template not found")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -220,10 +163,10 @@ class ExternalApiTemplateApi(Resource):
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check") @console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
class ExternalApiUseCheckApi(Resource): class ExternalApiUseCheckApi(Resource):
@console_ns.doc("check_external_api_usage") @api.doc("check_external_api_usage")
@console_ns.doc(description="Check if external knowledge API is being used") @api.doc(description="Check if external knowledge API is being used")
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"}) @api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
@console_ns.response(200, "Usage check completed successfully") @api.response(200, "Usage check completed successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -238,10 +181,10 @@ class ExternalApiUseCheckApi(Resource):
@console_ns.route("/datasets/external") @console_ns.route("/datasets/external")
class ExternalDatasetCreateApi(Resource): class ExternalDatasetCreateApi(Resource):
@console_ns.doc("create_external_dataset") @api.doc("create_external_dataset")
@console_ns.doc(description="Create external knowledge dataset") @api.doc(description="Create external knowledge dataset")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"CreateExternalDatasetRequest", "CreateExternalDatasetRequest",
{ {
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"), "external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
@ -251,16 +194,18 @@ class ExternalDatasetCreateApi(Resource):
}, },
) )
) )
@console_ns.response(201, "External dataset created successfully", dataset_detail_model) @api.response(201, "External dataset created successfully", dataset_detail_fields)
@console_ns.response(400, "Invalid parameters") @api.response(400, "Invalid parameters")
@console_ns.response(403, "Permission denied") @api.response(403, "Permission denied")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json") .add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
@ -296,11 +241,11 @@ class ExternalDatasetCreateApi(Resource):
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
class ExternalKnowledgeHitTestingApi(Resource): class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.doc("test_external_knowledge_retrieval") @api.doc("test_external_knowledge_retrieval")
@console_ns.doc(description="Test external knowledge retrieval for dataset") @api.doc(description="Test external knowledge retrieval for dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"ExternalHitTestingRequest", "ExternalHitTestingRequest",
{ {
"query": fields.String(required=True, description="Query text for testing"), "query": fields.String(required=True, description="Query text for testing"),
@ -309,9 +254,9 @@ class ExternalKnowledgeHitTestingApi(Resource):
}, },
) )
) )
@console_ns.response(200, "External hit testing completed successfully") @api.response(200, "External hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @api.response(400, "Invalid parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -354,10 +299,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
@console_ns.route("/test/retrieval") @console_ns.route("/test/retrieval")
class BedrockRetrievalApi(Resource): class BedrockRetrievalApi(Resource):
# this api is only for internal testing # this api is only for internal testing
@console_ns.doc("bedrock_retrieval_test") @api.doc("bedrock_retrieval_test")
@console_ns.doc(description="Bedrock retrieval test (internal use only)") @api.doc(description="Bedrock retrieval test (internal use only)")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"BedrockRetrievalTestRequest", "BedrockRetrievalTestRequest",
{ {
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"), "retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
@ -366,7 +311,7 @@ class BedrockRetrievalApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Bedrock retrieval test completed") @api.response(200, "Bedrock retrieval test completed")
def post(self): def post(self):
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
@ -12,11 +12,11 @@ from libs.login import login_required
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing") @console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase): class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval") @api.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval") @api.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"}) @api.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"HitTestingRequest", "HitTestingRequest",
{ {
"query": fields.String(required=True, description="Query text for testing"), "query": fields.String(required=True, description="Query text for testing"),
@ -26,9 +26,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
}, },
) )
) )
@console_ns.response(200, "Hit testing completed successfully") @api.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found") @api.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters") @api.response(400, "Invalid parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,7 +3,7 @@ from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -130,7 +130,7 @@ parser_datasource = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>") @console_ns.route("/auth/plugin/datasource/<path:provider_id>")
class DatasourceAuth(Resource): class DatasourceAuth(Resource):
@console_ns.expect(parser_datasource) @api.expect(parser_datasource)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -176,7 +176,7 @@ parser_datasource_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
class DatasourceAuthDeleteApi(Resource): class DatasourceAuthDeleteApi(Resource):
@console_ns.expect(parser_datasource_delete) @api.expect(parser_datasource_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -209,7 +209,7 @@ parser_datasource_update = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
class DatasourceAuthUpdateApi(Resource): class DatasourceAuthUpdateApi(Resource):
@console_ns.expect(parser_datasource_update) @api.expect(parser_datasource_update)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -267,7 +267,7 @@ parser_datasource_custom = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
class DatasourceAuthOauthCustomClient(Resource): class DatasourceAuthOauthCustomClient(Resource):
@console_ns.expect(parser_datasource_custom) @api.expect(parser_datasource_custom)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -306,7 +306,7 @@ parser_default = reqparse.RequestParser().add_argument("id", type=str, required=
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@console_ns.expect(parser_default) @api.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -334,7 +334,7 @@ parser_update_name = (
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name") @console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
class DatasourceUpdateProviderNameApi(Resource): class DatasourceUpdateProviderNameApi(Resource):
@console_ns.expect(parser_update_name) @api.expect(parser_update_name)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,10 +1,10 @@
from flask_restx import ( # type: ignore from flask_restx import ( # type: ignore
Resource, # type: ignore Resource, # type: ignore
reqparse,
) )
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -12,21 +12,17 @@ from models import Account
from models.dataset import Pipeline from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline import RagPipelineService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
class Parser(BaseModel): .add_argument("datasource_type", type=str, required=True, location="json")
inputs: dict .add_argument("credential_id", type=str, required=False, location="json")
datasource_type: str )
credential_id: str | None = None
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
class DataSourceContentPreviewApi(Resource): class DataSourceContentPreviewApi(Resource):
@console_ns.expect(console_ns.models[Parser.__name__], validate=True) @api.expect(parser)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -38,10 +34,15 @@ class DataSourceContentPreviewApi(Resource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise Forbidden() raise Forbidden()
args = Parser.model_validate(console_ns.payload) args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
inputs = args.inputs
datasource_type = args.datasource_type
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview( preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline, pipeline=pipeline,
@ -50,6 +51,6 @@ class DataSourceContentPreviewApi(Resource):
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=datasource_type,
is_published=True, is_published=True,
credential_id=args.credential_id, credential_id=args.get("credential_id"),
) )
return preview_content, 200 return preview_content, 200

View File

@ -1,11 +1,11 @@
from flask_restx import Resource, marshal_with, reqparse # type: ignore from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import console_ns
from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
edit_permission_required,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -21,11 +21,12 @@ class RagPipelineImportApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self): def post(self):
# Check user role first # Check user role first
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
parser = ( parser = (
reqparse.RequestParser() reqparse.RequestParser()
@ -70,10 +71,12 @@ class RagPipelineImportConfirmApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_fields) @marshal_with(pipeline_import_fields)
def post(self, import_id): def post(self, import_id):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
# Check user role first
if not current_user.has_edit_permission:
raise Forbidden()
# Create service with session # Create service with session
with Session(db.engine) as session: with Session(db.engine) as session:
@ -95,9 +98,12 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@login_required @login_required
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_fields) @marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
with Session(db.engine) as session: with Session(db.engine) as session:
import_service = RagPipelineDslService(session) import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline) result = import_service.check_dependencies(pipeline=pipeline)
@ -111,9 +117,12 @@ class RagPipelineExportApi(Resource):
@login_required @login_required
@get_rag_pipeline @get_rag_pipeline
@account_initialization_required @account_initialization_required
@edit_permission_required
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
# Add include_secret params current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args") parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
args = parser.parse_args() args = parser.parse_args()

View File

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.app.error import ( from controllers.console.app.error import (
ConversationCompletedError, ConversationCompletedError,
DraftWorkflowNotExist, DraftWorkflowNotExist,
@ -153,7 +153,7 @@ parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class RagPipelineDraftRunIterationNodeApi(Resource): class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.expect(parser_run) @api.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -187,11 +187,10 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class RagPipelineDraftRunLoopNodeApi(Resource): class RagPipelineDraftRunLoopNodeApi(Resource):
@console_ns.expect(parser_run) @api.expect(parser_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
@ -199,6 +198,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_run.parse_args() args = parser_run.parse_args()
@ -230,11 +231,10 @@ parser_draft_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
class DraftRagPipelineRunApi(Resource): class DraftRagPipelineRunApi(Resource):
@console_ns.expect(parser_draft_run) @api.expect(parser_draft_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -242,6 +242,8 @@ class DraftRagPipelineRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_draft_run.parse_args() args = parser_draft_run.parse_args()
@ -273,11 +275,10 @@ parser_published_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
class PublishedRagPipelineRunApi(Resource): class PublishedRagPipelineRunApi(Resource):
@console_ns.expect(parser_published_run) @api.expect(parser_published_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -285,6 +286,8 @@ class PublishedRagPipelineRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_published_run.parse_args() args = parser_published_run.parse_args()
@ -397,11 +400,10 @@ parser_rag_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource): class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @api.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
""" """
@ -409,6 +411,8 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_rag_run.parse_args() args = parser_rag_run.parse_args()
@ -437,10 +441,9 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@console_ns.expect(parser_rag_run) @api.expect(parser_rag_run)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str): def post(self, pipeline: Pipeline, node_id: str):
@ -449,6 +452,8 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_rag_run.parse_args() args = parser_rag_run.parse_args()
@ -482,10 +487,9 @@ parser_run_api = reqparse.RequestParser().add_argument(
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
class RagPipelineDraftNodeRunApi(Resource): class RagPipelineDraftNodeRunApi(Resource):
@console_ns.expect(parser_run_api) @api.expect(parser_run_api)
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_run_node_execution_fields) @marshal_with(workflow_run_node_execution_fields)
@ -495,6 +499,8 @@ class RagPipelineDraftNodeRunApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_run_api.parse_args() args = parser_run_api.parse_args()
@ -517,7 +523,6 @@ class RagPipelineDraftNodeRunApi(Resource):
class RagPipelineTaskStopApi(Resource): class RagPipelineTaskStopApi(Resource):
@setup_required @setup_required
@login_required @login_required
@edit_permission_required
@account_initialization_required @account_initialization_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline, task_id: str): def post(self, pipeline: Pipeline, task_id: str):
@ -526,6 +531,8 @@ class RagPipelineTaskStopApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
@ -537,7 +544,6 @@ class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@ -545,6 +551,9 @@ class PublishedRagPipelineApi(Resource):
Get published pipeline Get published pipeline
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
if not pipeline.is_published: if not pipeline.is_published:
return None return None
# fetch published workflow by pipeline # fetch published workflow by pipeline
@ -557,7 +566,6 @@ class PublishedRagPipelineApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def post(self, pipeline: Pipeline): def post(self, pipeline: Pipeline):
""" """
@ -565,6 +573,9 @@ class PublishedRagPipelineApi(Resource):
""" """
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
with Session(db.engine) as session: with Session(db.engine) as session:
pipeline = session.merge(pipeline) pipeline = session.merge(pipeline)
@ -591,12 +602,16 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
""" """
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
# Get default block configs # Get default block configs
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
return rag_pipeline_service.get_default_block_configs() return rag_pipeline_service.get_default_block_configs()
@ -607,16 +622,20 @@ parser_default = reqparse.RequestParser().add_argument("q", type=str, location="
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
class DefaultRagPipelineBlockConfigApi(Resource): class DefaultRagPipelineBlockConfigApi(Resource):
@console_ns.expect(parser_default) @api.expect(parser_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
def get(self, pipeline: Pipeline, block_type: str): def get(self, pipeline: Pipeline, block_type: str):
""" """
Get default block config Get default block config
""" """
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_default.parse_args() args = parser_default.parse_args()
q = args.get("q") q = args.get("q")
@ -644,11 +663,10 @@ parser_wf = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
class PublishedAllRagPipelineApi(Resource): class PublishedAllRagPipelineApi(Resource):
@console_ns.expect(parser_wf) @api.expect(parser_wf)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_pagination_fields) @marshal_with(workflow_pagination_fields)
def get(self, pipeline: Pipeline): def get(self, pipeline: Pipeline):
@ -656,6 +674,8 @@ class PublishedAllRagPipelineApi(Resource):
Get published workflows Get published workflows
""" """
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_wf.parse_args() args = parser_wf.parse_args()
page = args["page"] page = args["page"]
@ -696,11 +716,10 @@ parser_wf_id = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
class RagPipelineByIdApi(Resource): class RagPipelineByIdApi(Resource):
@console_ns.expect(parser_wf_id) @api.expect(parser_wf_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
@get_rag_pipeline @get_rag_pipeline
@marshal_with(workflow_fields) @marshal_with(workflow_fields)
def patch(self, pipeline: Pipeline, workflow_id: str): def patch(self, pipeline: Pipeline, workflow_id: str):
@ -709,6 +728,8 @@ class RagPipelineByIdApi(Resource):
""" """
# Check permission # Check permission
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.has_edit_permission:
raise Forbidden()
args = parser_wf_id.parse_args() args = parser_wf_id.parse_args()
@ -754,7 +775,7 @@ parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, r
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
class PublishedRagPipelineSecondStepApi(Resource): class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters) @api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -777,7 +798,7 @@ class PublishedRagPipelineSecondStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
class PublishedRagPipelineFirstStepApi(Resource): class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters) @api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -800,7 +821,7 @@ class PublishedRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
class DraftRagPipelineFirstStepApi(Resource): class DraftRagPipelineFirstStepApi(Resource):
@console_ns.expect(parser_parameters) @api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -823,7 +844,7 @@ class DraftRagPipelineFirstStepApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
class DraftRagPipelineSecondStepApi(Resource): class DraftRagPipelineSecondStepApi(Resource):
@console_ns.expect(parser_parameters) @api.expect(parser_parameters)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -854,7 +875,7 @@ parser_wf_run = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
class RagPipelineWorkflowRunListApi(Resource): class RagPipelineWorkflowRunListApi(Resource):
@console_ns.expect(parser_wf_run) @api.expect(parser_wf_run)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -975,7 +996,7 @@ parser_var = (
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect") @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
class RagPipelineDatasourceVariableApi(Resource): class RagPipelineDatasourceVariableApi(Resource):
@console_ns.expect(parser_var) @api.expect(parser_var)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1004,11 +1025,6 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
parser = reqparse.RequestParser()
parser.add_argument('type', type=str, location='args', required=False, default='all')
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type) recommended_plugins = rag_pipeline_service.get_recommended_plugins()
return recommended_plugins return recommended_plugins

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
@ -9,10 +9,10 @@ from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusA
@console_ns.route("/website/crawl") @console_ns.route("/website/crawl")
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@console_ns.doc("crawl_website") @api.doc("crawl_website")
@console_ns.doc(description="Crawl website content") @api.doc(description="Crawl website content")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"WebsiteCrawlRequest", "WebsiteCrawlRequest",
{ {
"provider": fields.String( "provider": fields.String(
@ -25,8 +25,8 @@ class WebsiteCrawlApi(Resource):
}, },
) )
) )
@console_ns.response(200, "Website crawl initiated successfully") @api.response(200, "Website crawl initiated successfully")
@console_ns.response(400, "Invalid crawl parameters") @api.response(400, "Invalid crawl parameters")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -62,12 +62,12 @@ class WebsiteCrawlApi(Resource):
@console_ns.route("/website/crawl/status/<string:job_id>") @console_ns.route("/website/crawl/status/<string:job_id>")
class WebsiteCrawlStatusApi(Resource): class WebsiteCrawlStatusApi(Resource):
@console_ns.doc("get_crawl_status") @api.doc("get_crawl_status")
@console_ns.doc(description="Get website crawl status") @api.doc(description="Get website crawl status")
@console_ns.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"}) @api.doc(params={"job_id": "Crawl job ID", "provider": "Crawl provider (firecrawl/watercrawl/jinareader)"})
@console_ns.response(200, "Crawl status retrieved successfully") @api.response(200, "Crawl status retrieved successfully")
@console_ns.response(404, "Crawl job not found") @api.response(404, "Crawl job not found")
@console_ns.response(400, "Invalid provider") @api.response(400, "Invalid provider")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -1,40 +1,44 @@
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
from models.dataset import Pipeline from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline(
view: Callable | None = None,
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
def get_rag_pipeline(view_func: Callable[P, R]): _, current_tenant_id = current_account_with_tenant()
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
_, current_tenant_id = current_account_with_tenant() pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
pipeline_id = kwargs.get("pipeline_id") del kwargs["pipeline_id"]
pipeline_id = str(pipeline_id)
del kwargs["pipeline_id"] pipeline = (
db.session.query(Pipeline)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)
pipeline = ( if not pipeline:
db.session.query(Pipeline) raise PipelineNotFoundError()
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
)
if not pipeline: kwargs["pipeline"] = pipeline
raise PipelineNotFoundError()
kwargs["pipeline"] = pipeline return view_func(*args, **kwargs)
return view_func(*args, **kwargs) return decorated_view
return decorated_view if view is None:
return decorator
else:
return decorator(view)

View File

@ -15,6 +15,7 @@ from controllers.console.app.error import (
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -30,7 +31,6 @@ from libs.login import current_user
from models import Account from models import Account
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
from .. import console_ns from .. import console_ns
@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
class CompletionApi(InstalledAppResource): class CompletionApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( parser = (
@ -102,18 +102,12 @@ class CompletionApi(InstalledAppResource):
class CompletionStopApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource):
def post(self, installed_app, task_id): def post(self, installed_app, task_id):
app_model = installed_app.app app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -190,12 +184,6 @@ class ChatStopApi(InstalledAppResource):
if not isinstance(current_user, Account): if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance") raise ValueError("current_user must be an Account instance")
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
AppTaskService.stop_task(
task_id=task_id,
invoke_from=InvokeFrom.EXPLORE,
user_id=current_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants.languages import languages from constants.languages import languages
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from libs.helper import AppIconUrlField from libs.helper import AppIconUrlField
from libs.login import current_user, login_required from libs.login import current_user, login_required
@ -40,7 +40,7 @@ parser_apps = reqparse.RequestParser().add_argument("language", type=str, locati
@console_ns.route("/explore/apps") @console_ns.route("/explore/apps")
class RecommendedAppListApi(Resource): class RecommendedAppListApi(Resource):
@console_ns.expect(parser_apps) @api.expect(parser_apps)
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(recommended_app_list_fields) @marshal_with(recommended_app_list_fields)

View File

@ -1,7 +1,7 @@
from flask_restx import Resource, fields, marshal_with, reqparse from flask_restx import Resource, fields, marshal_with, reqparse
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -9,24 +9,18 @@ from models.api_based_extension import APIBasedExtension
from services.api_based_extension_service import APIBasedExtensionService from services.api_based_extension_service import APIBasedExtensionService
from services.code_based_extension_service import CodeBasedExtensionService from services.code_based_extension_service import CodeBasedExtensionService
api_based_extension_model = console_ns.model("ApiBasedExtensionModel", api_based_extension_fields)
api_based_extension_list_model = fields.List(fields.Nested(api_based_extension_model))
@console_ns.route("/code-based-extension") @console_ns.route("/code-based-extension")
class CodeBasedExtensionAPI(Resource): class CodeBasedExtensionAPI(Resource):
@console_ns.doc("get_code_based_extension") @api.doc("get_code_based_extension")
@console_ns.doc(description="Get code-based extension data by module name") @api.doc(description="Get code-based extension data by module name")
@console_ns.expect( @api.expect(
console_ns.parser().add_argument( api.parser().add_argument("module", type=str, required=True, location="args", help="Extension module name")
"module", type=str, required=True, location="args", help="Extension module name"
)
) )
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model(
"CodeBasedExtensionResponse", "CodeBasedExtensionResponse",
{"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")}, {"module": fields.String(description="Module name"), "data": fields.Raw(description="Extension data")},
), ),
@ -43,21 +37,21 @@ class CodeBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension") @console_ns.route("/api-based-extension")
class APIBasedExtensionAPI(Resource): class APIBasedExtensionAPI(Resource):
@console_ns.doc("get_api_based_extensions") @api.doc("get_api_based_extensions")
@console_ns.doc(description="Get all API-based extensions for current tenant") @api.doc(description="Get all API-based extensions for current tenant")
@console_ns.response(200, "Success", api_based_extension_list_model) @api.response(200, "Success", fields.List(fields.Nested(api_based_extension_fields)))
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_model) @marshal_with(api_based_extension_fields)
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id) return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@console_ns.doc("create_api_based_extension") @api.doc("create_api_based_extension")
@console_ns.doc(description="Create a new API-based extension") @api.doc(description="Create a new API-based extension")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"CreateAPIBasedExtensionRequest", "CreateAPIBasedExtensionRequest",
{ {
"name": fields.String(required=True, description="Extension name"), "name": fields.String(required=True, description="Extension name"),
@ -66,13 +60,13 @@ class APIBasedExtensionAPI(Resource):
}, },
) )
) )
@console_ns.response(201, "Extension created successfully", api_based_extension_model) @api.response(201, "Extension created successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_model) @marshal_with(api_based_extension_fields)
def post(self): def post(self):
args = console_ns.payload args = api.payload
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
extension_data = APIBasedExtension( extension_data = APIBasedExtension(
@ -87,25 +81,25 @@ class APIBasedExtensionAPI(Resource):
@console_ns.route("/api-based-extension/<uuid:id>") @console_ns.route("/api-based-extension/<uuid:id>")
class APIBasedExtensionDetailAPI(Resource): class APIBasedExtensionDetailAPI(Resource):
@console_ns.doc("get_api_based_extension") @api.doc("get_api_based_extension")
@console_ns.doc(description="Get API-based extension by ID") @api.doc(description="Get API-based extension by ID")
@console_ns.doc(params={"id": "Extension ID"}) @api.doc(params={"id": "Extension ID"})
@console_ns.response(200, "Success", api_based_extension_model) @api.response(200, "Success", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_model) @marshal_with(api_based_extension_fields)
def get(self, id): def get(self, id):
api_based_extension_id = str(id) api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@console_ns.doc("update_api_based_extension") @api.doc("update_api_based_extension")
@console_ns.doc(description="Update API-based extension") @api.doc(description="Update API-based extension")
@console_ns.doc(params={"id": "Extension ID"}) @api.doc(params={"id": "Extension ID"})
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"UpdateAPIBasedExtensionRequest", "UpdateAPIBasedExtensionRequest",
{ {
"name": fields.String(required=True, description="Extension name"), "name": fields.String(required=True, description="Extension name"),
@ -114,18 +108,18 @@ class APIBasedExtensionDetailAPI(Resource):
}, },
) )
) )
@console_ns.response(200, "Extension updated successfully", api_based_extension_model) @api.response(200, "Extension updated successfully", api_based_extension_fields)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(api_based_extension_model) @marshal_with(api_based_extension_fields)
def post(self, id): def post(self, id):
api_based_extension_id = str(id) api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
args = console_ns.payload args = api.payload
extension_data_from_db.name = args["name"] extension_data_from_db.name = args["name"]
extension_data_from_db.api_endpoint = args["api_endpoint"] extension_data_from_db.api_endpoint = args["api_endpoint"]
@ -135,10 +129,10 @@ class APIBasedExtensionDetailAPI(Resource):
return APIBasedExtensionService.save(extension_data_from_db) return APIBasedExtensionService.save(extension_data_from_db)
@console_ns.doc("delete_api_based_extension") @api.doc("delete_api_based_extension")
@console_ns.doc(description="Delete API-based extension") @api.doc(description="Delete API-based extension")
@console_ns.doc(params={"id": "Extension ID"}) @api.doc(params={"id": "Extension ID"})
@console_ns.response(204, "Extension deleted successfully") @api.response(204, "Extension deleted successfully")
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -3,18 +3,18 @@ from flask_restx import Resource, fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.feature_service import FeatureService from services.feature_service import FeatureService
from . import console_ns from . import api, console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required from .wraps import account_initialization_required, cloud_utm_record, setup_required
@console_ns.route("/features") @console_ns.route("/features")
class FeatureApi(Resource): class FeatureApi(Resource):
@console_ns.doc("get_tenant_features") @api.doc("get_tenant_features")
@console_ns.doc(description="Get feature configuration for current tenant") @api.doc(description="Get feature configuration for current tenant")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}), api.model("FeatureResponse", {"features": fields.Raw(description="Feature configuration object")}),
) )
@setup_required @setup_required
@login_required @login_required
@ -29,14 +29,12 @@ class FeatureApi(Resource):
@console_ns.route("/system-features") @console_ns.route("/system-features")
class SystemFeatureApi(Resource): class SystemFeatureApi(Resource):
@console_ns.doc("get_system_features") @api.doc("get_system_features")
@console_ns.doc(description="Get system-wide feature configuration") @api.doc(description="Get system-wide feature configuration")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model("SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}),
"SystemFeatureResponse", {"features": fields.Raw(description="System feature configuration object")}
),
) )
def get(self): def get(self):
"""Get system-wide feature configuration""" """Get system-wide feature configuration"""

View File

@ -11,19 +11,19 @@ from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
from . import console_ns from . import api, console_ns
from .error import AlreadySetupError, InitValidateFailedError from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@console_ns.route("/init") @console_ns.route("/init")
class InitValidateAPI(Resource): class InitValidateAPI(Resource):
@console_ns.doc("get_init_status") @api.doc("get_init_status")
@console_ns.doc(description="Get initialization validation status") @api.doc(description="Get initialization validation status")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
model=console_ns.model( model=api.model(
"InitStatusResponse", "InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])}, {"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
), ),
@ -35,20 +35,20 @@ class InitValidateAPI(Resource):
return {"status": "finished"} return {"status": "finished"}
return {"status": "not_started"} return {"status": "not_started"}
@console_ns.doc("validate_init_password") @api.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition") @api.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"InitValidateRequest", "InitValidateRequest",
{"password": fields.String(required=True, description="Initialization password", max_length=30)}, {"password": fields.String(required=True, description="Initialization password", max_length=30)},
) )
) )
@console_ns.response( @api.response(
201, 201,
"Success", "Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}), model=api.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
) )
@console_ns.response(400, "Already setup or validation failed") @api.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Validate initialization password""" """Validate initialization password"""

View File

@ -1,16 +1,16 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from . import console_ns from . import api, console_ns
@console_ns.route("/ping") @console_ns.route("/ping")
class PingApi(Resource): class PingApi(Resource):
@console_ns.doc("health_check") @api.doc("health_check")
@console_ns.doc(description="Health check endpoint for connection testing") @api.doc(description="Health check endpoint for connection testing")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}), api.model("PingResponse", {"result": fields.String(description="Health check result", example="pong")}),
) )
def get(self): def get(self):
"""Health check endpoint for connection testing""" """Health check endpoint for connection testing"""

View File

@ -10,6 +10,7 @@ from controllers.common.errors import (
RemoteFileUploadError, RemoteFileUploadError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import api
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
@ -41,7 +42,7 @@ parser_upload = reqparse.RequestParser().add_argument("url", type=str, required=
@console_ns.route("/remote-files/upload") @console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource): class RemoteFileUploadApi(Resource):
@console_ns.expect(parser_upload) @api.expect(parser_upload)
@marshal_with(file_fields_with_signed_url) @marshal_with(file_fields_with_signed_url)
def post(self): def post(self):
args = parser_upload.parse_args() args = parser_upload.parse_args()

View File

@ -7,7 +7,7 @@ from libs.password import valid_password
from models.model import DifySetup, db from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
from . import console_ns from . import api, console_ns
from .error import AlreadySetupError, NotInitValidateError from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted from .wraps import only_edition_self_hosted
@ -15,12 +15,12 @@ from .wraps import only_edition_self_hosted
@console_ns.route("/setup") @console_ns.route("/setup")
class SetupApi(Resource): class SetupApi(Resource):
@console_ns.doc("get_setup_status") @api.doc("get_setup_status")
@console_ns.doc(description="Get system setup status") @api.doc(description="Get system setup status")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model(
"SetupStatusResponse", "SetupStatusResponse",
{ {
"step": fields.String(description="Setup step status", enum=["not_started", "finished"]), "step": fields.String(description="Setup step status", enum=["not_started", "finished"]),
@ -40,10 +40,10 @@ class SetupApi(Resource):
return {"step": "not_started"} return {"step": "not_started"}
return {"step": "finished"} return {"step": "finished"}
@console_ns.doc("setup_system") @api.doc("setup_system")
@console_ns.doc(description="Initialize system setup with admin account") @api.doc(description="Initialize system setup with admin account")
@console_ns.expect( @api.expect(
console_ns.model( api.model(
"SetupRequest", "SetupRequest",
{ {
"email": fields.String(required=True, description="Admin email address"), "email": fields.String(required=True, description="Admin email address"),
@ -53,10 +53,8 @@ class SetupApi(Resource):
}, },
) )
) )
@console_ns.response( @api.response(201, "Success", api.model("SetupResponse", {"result": fields.String(description="Setup result")}))
201, "Success", console_ns.model("SetupResponse", {"result": fields.String(description="Setup result")}) @api.response(400, "Already setup or validation failed")
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted @only_edition_self_hosted
def post(self): def post(self):
"""Initialize system setup with admin account""" """Initialize system setup with admin account"""

View File

@ -2,8 +2,8 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import dataset_tag_fields from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from models.model import Tag from models.model import Tag
@ -43,7 +43,7 @@ class TagListApi(Resource):
return tags, 200 return tags, 200
@console_ns.expect(parser_tags) @api.expect(parser_tags)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -68,7 +68,7 @@ parser_tag_id = reqparse.RequestParser().add_argument(
@console_ns.route("/tags/<uuid:tag_id>") @console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource): class TagUpdateDeleteApi(Resource):
@console_ns.expect(parser_tag_id) @api.expect(parser_tag_id)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -91,9 +91,12 @@ class TagUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@edit_permission_required
def delete(self, tag_id): def delete(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id = str(tag_id) tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.has_edit_permission:
raise Forbidden()
TagService.delete_tag(tag_id) TagService.delete_tag(tag_id)
@ -110,7 +113,7 @@ parser_create = (
@console_ns.route("/tag-bindings/create") @console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource): class TagBindingCreateApi(Resource):
@console_ns.expect(parser_create) @api.expect(parser_create)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -136,7 +139,7 @@ parser_remove = (
@console_ns.route("/tag-bindings/remove") @console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource): class TagBindingDeleteApi(Resource):
@console_ns.expect(parser_remove) @api.expect(parser_remove)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required

View File

@ -7,7 +7,7 @@ from packaging import version
from configs import dify_config from configs import dify_config
from . import console_ns from . import api, console_ns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,13 +18,13 @@ parser = reqparse.RequestParser().add_argument(
@console_ns.route("/version") @console_ns.route("/version")
class VersionApi(Resource): class VersionApi(Resource):
@console_ns.doc("check_version_update") @api.doc("check_version_update")
@console_ns.doc(description="Check for application version updates") @api.doc(description="Check for application version updates")
@console_ns.expect(parser) @api.expect(parser)
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model(
"VersionResponse", "VersionResponse",
{ {
"version": fields.String(description="Latest version number"), "version": fields.String(description="Latest version number"),
@ -58,7 +58,7 @@ class VersionApi(Resource):
response = httpx.get( response = httpx.get(
check_update_url, check_update_url,
params={"current_version": args["current_version"]}, params={"current_version": args["current_version"]},
timeout=httpx.Timeout(timeout=10.0, connect=3.0), timeout=httpx.Timeout(connect=3, read=10),
) )
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))

View File

@ -1,16 +1,14 @@
from datetime import datetime from datetime import datetime
from typing import Literal
import pytz import pytz
from flask import request from flask import request
from flask_restx import Resource, fields, marshal_with from flask_restx import Resource, fields, marshal_with, reqparse
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
@ -44,160 +42,20 @@ from services.account_service import AccountService
from services.billing_service import BillingService from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _init_parser():
class AccountInitPayload(BaseModel): parser = reqparse.RequestParser()
interface_language: str if dify_config.EDITION == "CLOUD":
timezone: str parser.add_argument("invitation_code", type=str, location="json")
invitation_code: str | None = None parser.add_argument("interface_language", type=supported_language, required=True, location="json").add_argument(
"timezone", type=timezone, required=True, location="json"
@field_validator("interface_language") )
@classmethod return parser
def validate_language(cls, value: str) -> str:
return supported_language(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountNamePayload(BaseModel):
name: str = Field(min_length=3, max_length=30)
class AccountAvatarPayload(BaseModel):
avatar: str
class AccountInterfaceLanguagePayload(BaseModel):
interface_language: str
@field_validator("interface_language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
class AccountInterfaceThemePayload(BaseModel):
interface_theme: Literal["light", "dark"]
class AccountTimezonePayload(BaseModel):
timezone: str
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str) -> str:
return timezone(value)
class AccountPasswordPayload(BaseModel):
password: str | None = None
new_password: str
repeat_new_password: str
@model_validator(mode="after")
def check_passwords_match(self) -> "AccountPasswordPayload":
if self.new_password != self.repeat_new_password:
raise RepeatPasswordNotMatchError()
return self
class AccountDeletePayload(BaseModel):
token: str
code: str
class AccountDeletionFeedbackPayload(BaseModel):
email: str
feedback: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class EducationActivatePayload(BaseModel):
token: str
institution: str
role: str
class EducationAutocompleteQuery(BaseModel):
keywords: str
page: int = 0
limit: int = 20
class ChangeEmailSendPayload(BaseModel):
email: str
language: str | None = None
phase: str | None = None
token: str | None = None
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailValidityPayload(BaseModel):
email: str
code: str
token: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class ChangeEmailResetPayload(BaseModel):
new_email: str
token: str
@field_validator("new_email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
class CheckEmailUniquePayload(BaseModel):
email: str
@field_validator("email")
@classmethod
def validate_email(cls, value: str) -> str:
return email(value)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(AccountInitPayload)
reg(AccountNamePayload)
reg(AccountAvatarPayload)
reg(AccountInterfaceLanguagePayload)
reg(AccountInterfaceThemePayload)
reg(AccountTimezonePayload)
reg(AccountPasswordPayload)
reg(AccountDeletePayload)
reg(AccountDeletionFeedbackPayload)
reg(EducationActivatePayload)
reg(EducationAutocompleteQuery)
reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
@console_ns.route("/account/init") @console_ns.route("/account/init")
class AccountInitApi(Resource): class AccountInitApi(Resource):
@console_ns.expect(console_ns.models[AccountInitPayload.__name__]) @api.expect(_init_parser())
@setup_required @setup_required
@login_required @login_required
def post(self): def post(self):
@ -206,18 +64,17 @@ class AccountInitApi(Resource):
if account.status == "active": if account.status == "active":
raise AccountAlreadyInitedError() raise AccountAlreadyInitedError()
payload = console_ns.payload or {} args = _init_parser().parse_args()
args = AccountInitPayload.model_validate(payload)
if dify_config.EDITION == "CLOUD": if dify_config.EDITION == "CLOUD":
if not args.invitation_code: if not args["invitation_code"]:
raise ValueError("invitation_code is required") raise ValueError("invitation_code is required")
# check invitation code # check invitation code
invitation_code = ( invitation_code = (
db.session.query(InvitationCode) db.session.query(InvitationCode)
.where( .where(
InvitationCode.code == args.invitation_code, InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused", InvitationCode.status == "unused",
) )
.first() .first()
@ -231,8 +88,8 @@ class AccountInitApi(Resource):
invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id invitation_code.used_by_account_id = account.id
account.interface_language = args.interface_language account.interface_language = args["interface_language"]
account.timezone = args.timezone account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = "active" account.status = "active"
account.initialized_at = naive_utc_now() account.initialized_at = naive_utc_now()
@ -253,104 +110,137 @@ class AccountProfileApi(Resource):
return current_user return current_user
parser_name = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/account/name") @console_ns.route("/account/name")
class AccountNameApi(Resource): class AccountNameApi(Resource):
@console_ns.expect(console_ns.models[AccountNamePayload.__name__]) @api.expect(parser_name)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_name.parse_args()
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name) # Validate account name length
if len(args["name"]) < 3 or len(args["name"]) > 30:
raise ValueError("Account name must be between 3 and 30 characters.")
updated_account = AccountService.update_account(current_user, name=args["name"])
return updated_account return updated_account
parser_avatar = reqparse.RequestParser().add_argument("avatar", type=str, required=True, location="json")
@console_ns.route("/account/avatar") @console_ns.route("/account/avatar")
class AccountAvatarApi(Resource): class AccountAvatarApi(Resource):
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__]) @api.expect(parser_avatar)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_avatar.parse_args()
args = AccountAvatarPayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, avatar=args.avatar) updated_account = AccountService.update_account(current_user, avatar=args["avatar"])
return updated_account return updated_account
parser_interface = reqparse.RequestParser().add_argument(
"interface_language", type=supported_language, required=True, location="json"
)
@console_ns.route("/account/interface-language") @console_ns.route("/account/interface-language")
class AccountInterfaceLanguageApi(Resource): class AccountInterfaceLanguageApi(Resource):
@console_ns.expect(console_ns.models[AccountInterfaceLanguagePayload.__name__]) @api.expect(parser_interface)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_interface.parse_args()
args = AccountInterfaceLanguagePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language) updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])
return updated_account return updated_account
parser_theme = reqparse.RequestParser().add_argument(
"interface_theme", type=str, choices=["light", "dark"], required=True, location="json"
)
@console_ns.route("/account/interface-theme") @console_ns.route("/account/interface-theme")
class AccountInterfaceThemeApi(Resource): class AccountInterfaceThemeApi(Resource):
@console_ns.expect(console_ns.models[AccountInterfaceThemePayload.__name__]) @api.expect(parser_theme)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_theme.parse_args()
args = AccountInterfaceThemePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme) updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])
return updated_account return updated_account
parser_timezone = reqparse.RequestParser().add_argument("timezone", type=str, required=True, location="json")
@console_ns.route("/account/timezone") @console_ns.route("/account/timezone")
class AccountTimezoneApi(Resource): class AccountTimezoneApi(Resource):
@console_ns.expect(console_ns.models[AccountTimezonePayload.__name__]) @api.expect(parser_timezone)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_timezone.parse_args()
args = AccountTimezonePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, timezone=args.timezone) # Validate timezone string, e.g. America/New_York, Asia/Shanghai
if args["timezone"] not in pytz.all_timezones:
raise ValueError("Invalid timezone string.")
updated_account = AccountService.update_account(current_user, timezone=args["timezone"])
return updated_account return updated_account
parser_pw = (
reqparse.RequestParser()
.add_argument("password", type=str, required=False, location="json")
.add_argument("new_password", type=str, required=True, location="json")
.add_argument("repeat_new_password", type=str, required=True, location="json")
)
@console_ns.route("/account/password") @console_ns.route("/account/password")
class AccountPasswordApi(Resource): class AccountPasswordApi(Resource):
@console_ns.expect(console_ns.models[AccountPasswordPayload.__name__]) @api.expect(parser_pw)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_pw.parse_args()
args = AccountPasswordPayload.model_validate(payload)
if args["new_password"] != args["repeat_new_password"]:
raise RepeatPasswordNotMatchError()
try: try:
AccountService.update_account_password(current_user, args.password, args.new_password) AccountService.update_account_password(current_user, args["password"], args["new_password"])
except ServiceCurrentPasswordIncorrectError: except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError() raise CurrentPasswordIncorrectError()
@ -426,19 +316,25 @@ class AccountDeleteVerifyApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
parser_delete = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
)
@console_ns.route("/account/delete") @console_ns.route("/account/delete")
class AccountDeleteApi(Resource): class AccountDeleteApi(Resource):
@console_ns.expect(console_ns.models[AccountDeletePayload.__name__]) @api.expect(parser_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_delete.parse_args()
args = AccountDeletePayload.model_validate(payload)
if not AccountService.verify_account_deletion_code(args.token, args.code): if not AccountService.verify_account_deletion_code(args["token"], args["code"]):
raise InvalidAccountDeletionCodeError() raise InvalidAccountDeletionCodeError()
AccountService.delete_account(account) AccountService.delete_account(account)
@ -446,15 +342,21 @@ class AccountDeleteApi(Resource):
return {"result": "success"} return {"result": "success"}
parser_feedback = (
reqparse.RequestParser()
.add_argument("email", type=str, required=True, location="json")
.add_argument("feedback", type=str, required=True, location="json")
)
@console_ns.route("/account/delete/feedback") @console_ns.route("/account/delete/feedback")
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@console_ns.expect(console_ns.models[AccountDeletionFeedbackPayload.__name__]) @api.expect(parser_feedback)
@setup_required @setup_required
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_feedback.parse_args()
args = AccountDeletionFeedbackPayload.model_validate(payload)
BillingService.update_account_deletion_feedback(args.email, args.feedback) BillingService.update_account_deletion_feedback(args["email"], args["feedback"])
return {"result": "success"} return {"result": "success"}
@ -477,6 +379,14 @@ class EducationVerifyApi(Resource):
return BillingService.EducationIdentity.verify(account.id, account.email) return BillingService.EducationIdentity.verify(account.id, account.email)
parser_edu = (
reqparse.RequestParser()
.add_argument("token", type=str, required=True, location="json")
.add_argument("institution", type=str, required=True, location="json")
.add_argument("role", type=str, required=True, location="json")
)
@console_ns.route("/account/education") @console_ns.route("/account/education")
class EducationApi(Resource): class EducationApi(Resource):
status_fields = { status_fields = {
@ -486,7 +396,7 @@ class EducationApi(Resource):
"allow_refresh": fields.Boolean, "allow_refresh": fields.Boolean,
} }
@console_ns.expect(console_ns.models[EducationActivatePayload.__name__]) @api.expect(parser_edu)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -495,10 +405,9 @@ class EducationApi(Resource):
def post(self): def post(self):
account, _ = current_account_with_tenant() account, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_edu.parse_args()
args = EducationActivatePayload.model_validate(payload)
return BillingService.EducationIdentity.activate(account, args.token, args.institution, args.role) return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@setup_required @setup_required
@login_required @login_required
@ -516,6 +425,14 @@ class EducationApi(Resource):
return res return res
parser_autocomplete = (
reqparse.RequestParser()
.add_argument("keywords", type=str, required=True, location="args")
.add_argument("page", type=int, required=False, location="args", default=0)
.add_argument("limit", type=int, required=False, location="args", default=20)
)
@console_ns.route("/account/education/autocomplete") @console_ns.route("/account/education/autocomplete")
class EducationAutoCompleteApi(Resource): class EducationAutoCompleteApi(Resource):
data_fields = { data_fields = {
@ -524,7 +441,7 @@ class EducationAutoCompleteApi(Resource):
"has_next": fields.Boolean, "has_next": fields.Boolean,
} }
@console_ns.expect(console_ns.models[EducationAutocompleteQuery.__name__]) @api.expect(parser_autocomplete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -532,39 +449,46 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled @cloud_edition_billing_enabled
@marshal_with(data_fields) @marshal_with(data_fields)
def get(self): def get(self):
payload = request.args.to_dict(flat=True) # type: ignore args = parser_autocomplete.parse_args()
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit) return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
parser_change_email = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("language", type=str, required=False, location="json")
.add_argument("phase", type=str, required=False, location="json")
.add_argument("token", type=str, required=False, location="json")
)
@console_ns.route("/account/change-email") @console_ns.route("/account/change-email")
class ChangeEmailSendEmailApi(Resource): class ChangeEmailSendEmailApi(Resource):
@console_ns.expect(console_ns.models[ChangeEmailSendPayload.__name__]) @api.expect(parser_change_email)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_change_email.parse_args()
args = ChangeEmailSendPayload.model_validate(payload)
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
if args.language is not None and args.language == "zh-Hans": if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
account = None account = None
user_email = args.email user_email = args["email"]
if args.phase is not None and args.phase == "new_email": if args["phase"] is not None and args["phase"] == "new_email":
if args.token is None: if args["token"] is None:
raise InvalidTokenError() raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token) reset_data = AccountService.get_change_email_data(args["token"])
if reset_data is None: if reset_data is None:
raise InvalidTokenError() raise InvalidTokenError()
user_email = reset_data.get("email", "") user_email = reset_data.get("email", "")
@ -573,103 +497,118 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidEmailError() raise InvalidEmailError()
else: else:
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
if account is None: if account is None:
raise AccountNotFound() raise AccountNotFound()
token = AccountService.send_change_email_email( token = AccountService.send_change_email_email(
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
) )
return {"result": "success", "data": token} return {"result": "success", "data": token}
parser_validity = (
reqparse.RequestParser()
.add_argument("email", type=email, required=True, location="json")
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/validity") @console_ns.route("/account/change-email/validity")
class ChangeEmailCheckApi(Resource): class ChangeEmailCheckApi(Resource):
@console_ns.expect(console_ns.models[ChangeEmailValidityPayload.__name__]) @api.expect(parser_validity)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_validity.parse_args()
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email user_email = args["email"]
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email) is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
if is_change_email_error_rate_limit: if is_change_email_error_rate_limit:
raise EmailChangeLimitError() raise EmailChangeLimitError()
token_data = AccountService.get_change_email_data(args.token) token_data = AccountService.get_change_email_data(args["token"])
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args.code != token_data.get("code"): if args["code"] != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email) AccountService.add_change_email_error_rate_limit(args["email"])
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_change_email_token(args.token) AccountService.revoke_change_email_token(args["token"])
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token( _, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={} user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
) )
AccountService.reset_change_email_error_rate_limit(args.email) AccountService.reset_change_email_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_reset = (
reqparse.RequestParser()
.add_argument("new_email", type=email, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/account/change-email/reset") @console_ns.route("/account/change-email/reset")
class ChangeEmailResetApi(Resource): class ChangeEmailResetApi(Resource):
@console_ns.expect(console_ns.models[ChangeEmailResetPayload.__name__]) @api.expect(parser_reset)
@enable_change_email @enable_change_email
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(account_fields) @marshal_with(account_fields)
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_reset.parse_args()
args = ChangeEmailResetPayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.new_email): if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError() raise AccountInFreezeError()
if not AccountService.check_email_unique(args.new_email): if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token) reset_data = AccountService.get_change_email_data(args["token"])
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token) AccountService.revoke_change_email_token(args["token"])
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args.new_email) updated_account = AccountService.update_account_email(current_user, email=args["new_email"])
AccountService.send_change_email_completed_notify_email( AccountService.send_change_email_completed_notify_email(
email=args.new_email, email=args["new_email"],
) )
return updated_account return updated_account
parser_check = reqparse.RequestParser().add_argument("email", type=email, required=True, location="json")
@console_ns.route("/account/change-email/check-email-unique") @console_ns.route("/account/change-email/check-email-unique")
class CheckEmailUnique(Resource): class CheckEmailUnique(Resource):
@console_ns.expect(console_ns.models[CheckEmailUniquePayload.__name__]) @api.expect(parser_check)
@setup_required @setup_required
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_check.parse_args()
args = CheckEmailUniquePayload.model_validate(payload) if AccountService.is_account_in_freeze(args["email"]):
if AccountService.is_account_in_freeze(args.email):
raise AccountInFreezeError() raise AccountInFreezeError()
if not AccountService.check_email_unique(args.email): if not AccountService.check_email_unique(args["email"]):
raise EmailAlreadyInUseError() raise EmailAlreadyInUseError()
return {"result": "success"} return {"result": "success"}

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields from flask_restx import Resource, fields
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -9,9 +9,9 @@ from services.agent_service import AgentService
@console_ns.route("/workspaces/current/agent-providers") @console_ns.route("/workspaces/current/agent-providers")
class AgentProviderListApi(Resource): class AgentProviderListApi(Resource):
@console_ns.doc("list_agent_providers") @api.doc("list_agent_providers")
@console_ns.doc(description="Get list of available agent providers") @api.doc(description="Get list of available agent providers")
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
fields.List(fields.Raw(description="Agent provider information")), fields.List(fields.Raw(description="Agent provider information")),
@ -31,10 +31,10 @@ class AgentProviderListApi(Resource):
@console_ns.route("/workspaces/current/agent-provider/<path:provider_name>") @console_ns.route("/workspaces/current/agent-provider/<path:provider_name>")
class AgentProviderApi(Resource): class AgentProviderApi(Resource):
@console_ns.doc("get_agent_provider") @api.doc("get_agent_provider")
@console_ns.doc(description="Get specific agent provider details") @api.doc(description="Get specific agent provider details")
@console_ns.doc(params={"provider_name": "Agent provider name"}) @api.doc(params={"provider_name": "Agent provider name"})
@console_ns.response( @api.response(
200, 200,
"Success", "Success",
fields.Raw(description="Agent provider details"), fields.Raw(description="Agent provider details"),

View File

@ -1,82 +1,62 @@
from typing import Any from flask_restx import Resource, fields, reqparse
from werkzeug.exceptions import Forbidden
from flask import request from controllers.console import api, console_ns
from flask_restx import Resource, fields from controllers.console.wraps import account_initialization_required, setup_required
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginPermissionDeniedError from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.plugin.endpoint_service import EndpointService from services.plugin.endpoint_service import EndpointService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class EndpointCreatePayload(BaseModel):
plugin_unique_identifier: str
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
class EndpointListQuery(BaseModel):
page: int = Field(ge=1)
page_size: int = Field(gt=0)
class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
@console_ns.route("/workspaces/current/endpoints/create") @console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource): class EndpointCreateApi(Resource):
@console_ns.doc("create_endpoint") @api.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint") @api.doc(description="Create a new plugin endpoint")
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__]) @api.expect(
@console_ns.response( api.model(
"EndpointCreateRequest",
{
"plugin_unique_identifier": fields.String(required=True, description="Plugin unique identifier"),
"settings": fields.Raw(required=True, description="Endpoint settings"),
"name": fields.String(required=True, description="Endpoint name"),
},
)
)
@api.response(
200, 200,
"Endpoint created successfully", "Endpoint created successfully",
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}), api.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
args = EndpointCreatePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("plugin_unique_identifier", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
plugin_unique_identifier = args["plugin_unique_identifier"]
settings = args["settings"]
name = args["name"]
try: try:
return { return {
"success": EndpointService.create_endpoint( "success": EndpointService.create_endpoint(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier, plugin_unique_identifier=plugin_unique_identifier,
name=args.name, name=name,
settings=args.settings, settings=settings,
) )
} }
except PluginPermissionDeniedError as e: except PluginPermissionDeniedError as e:
@ -85,15 +65,17 @@ class EndpointCreateApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list") @console_ns.route("/workspaces/current/endpoints/list")
class EndpointListApi(Resource): class EndpointListApi(Resource):
@console_ns.doc("list_endpoints") @api.doc("list_endpoints")
@console_ns.doc(description="List plugin endpoints with pagination") @api.doc(description="List plugin endpoints with pagination")
@console_ns.expect(console_ns.models[EndpointListQuery.__name__]) @api.expect(
@console_ns.response( api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
)
@api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model("EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}),
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
) )
@setup_required @setup_required
@login_required @login_required
@ -101,10 +83,15 @@ class EndpointListApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
args = parser.parse_args()
page = args.page page = args["page"]
page_size = args.page_size page_size = args["page_size"]
return jsonable_encoder( return jsonable_encoder(
{ {
@ -120,13 +107,18 @@ class EndpointListApi(Resource):
@console_ns.route("/workspaces/current/endpoints/list/plugin") @console_ns.route("/workspaces/current/endpoints/list/plugin")
class EndpointListForSinglePluginApi(Resource): class EndpointListForSinglePluginApi(Resource):
@console_ns.doc("list_plugin_endpoints") @api.doc("list_plugin_endpoints")
@console_ns.doc(description="List endpoints for a specific plugin") @api.doc(description="List endpoints for a specific plugin")
@console_ns.expect(console_ns.models[EndpointListForPluginQuery.__name__]) @api.expect(
@console_ns.response( api.parser()
.add_argument("page", type=int, required=True, location="args", help="Page number")
.add_argument("page_size", type=int, required=True, location="args", help="Page size")
.add_argument("plugin_id", type=str, required=True, location="args", help="Plugin ID")
)
@api.response(
200, 200,
"Success", "Success",
console_ns.model( api.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))} "PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
), ),
) )
@ -136,11 +128,17 @@ class EndpointListForSinglePluginApi(Resource):
def get(self): def get(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
.add_argument("plugin_id", type=str, required=True, location="args")
)
args = parser.parse_args()
page = args.page page = args["page"]
page_size = args.page_size page_size = args["page_size"]
plugin_id = args.plugin_id plugin_id = args["plugin_id"]
return jsonable_encoder( return jsonable_encoder(
{ {
@ -157,111 +155,147 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.route("/workspaces/current/endpoints/delete") @console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource): class EndpointDeleteApi(Resource):
@console_ns.doc("delete_endpoint") @api.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint") @api.doc(description="Delete a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @api.expect(
@console_ns.response( api.model("EndpointDeleteRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200, 200,
"Endpoint deleted successfully", "Endpoint deleted successfully",
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}), api.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload) parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
if not user.is_admin_or_owner:
raise Forbidden()
endpoint_id = args["endpoint_id"]
return { return {
"success": EndpointService.delete_endpoint( "success": EndpointService.delete_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }
@console_ns.route("/workspaces/current/endpoints/update") @console_ns.route("/workspaces/current/endpoints/update")
class EndpointUpdateApi(Resource): class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint") @api.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint") @api.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__]) @api.expect(
@console_ns.response( api.model(
"EndpointUpdateRequest",
{
"endpoint_id": fields.String(required=True, description="Endpoint ID"),
"settings": fields.Raw(required=True, description="Updated settings"),
"name": fields.String(required=True, description="Updated name"),
},
)
)
@api.response(
200, 200,
"Endpoint updated successfully", "Endpoint updated successfully",
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}), api.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload) parser = (
reqparse.RequestParser()
.add_argument("endpoint_id", type=str, required=True)
.add_argument("settings", type=dict, required=True)
.add_argument("name", type=str, required=True)
)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
settings = args["settings"]
name = args["name"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.update_endpoint( "success": EndpointService.update_endpoint(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user.id, user_id=user.id,
endpoint_id=args.endpoint_id, endpoint_id=endpoint_id,
name=args.name, name=name,
settings=args.settings, settings=settings,
) )
} }
@console_ns.route("/workspaces/current/endpoints/enable") @console_ns.route("/workspaces/current/endpoints/enable")
class EndpointEnableApi(Resource): class EndpointEnableApi(Resource):
@console_ns.doc("enable_endpoint") @api.doc("enable_endpoint")
@console_ns.doc(description="Enable a plugin endpoint") @api.doc(description="Enable a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @api.expect(
@console_ns.response( api.model("EndpointEnableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200, 200,
"Endpoint enabled successfully", "Endpoint enabled successfully",
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}), api.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload) parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.enable_endpoint( "success": EndpointService.enable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }
@console_ns.route("/workspaces/current/endpoints/disable") @console_ns.route("/workspaces/current/endpoints/disable")
class EndpointDisableApi(Resource): class EndpointDisableApi(Resource):
@console_ns.doc("disable_endpoint") @api.doc("disable_endpoint")
@console_ns.doc(description="Disable a plugin endpoint") @api.doc(description="Disable a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__]) @api.expect(
@console_ns.response( api.model("EndpointDisableRequest", {"endpoint_id": fields.String(required=True, description="Endpoint ID")})
)
@api.response(
200, 200,
"Endpoint disabled successfully", "Endpoint disabled successfully",
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}), api.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
) )
@console_ns.response(403, "Admin privileges required") @api.response(403, "Admin privileges required")
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload) parser = reqparse.RequestParser().add_argument("endpoint_id", type=str, required=True)
args = parser.parse_args()
endpoint_id = args["endpoint_id"]
if not user.is_admin_or_owner:
raise Forbidden()
return { return {
"success": EndpointService.disable_endpoint( "success": EndpointService.disable_endpoint(tenant_id=tenant_id, user_id=user.id, endpoint_id=endpoint_id)
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
} }

View File

@ -1,12 +1,11 @@
from urllib import parse from urllib import parse
from flask import abort, request from flask import abort, request
from flask_restx import Resource, marshal_with from flask_restx import Resource, marshal_with, reqparse
from pydantic import BaseModel, Field
import services import services
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.auth.error import ( from controllers.console.auth.error import (
CannotTransferOwnerToSelfError, CannotTransferOwnerToSelfError,
EmailCodeError, EmailCodeError,
@ -32,42 +31,6 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.errors.account import AccountAlreadyInTenantError from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class MemberInvitePayload(BaseModel):
emails: list[str] = Field(default_factory=list)
role: TenantAccountRole
language: str | None = None
class MemberRoleUpdatePayload(BaseModel):
role: str
class OwnerTransferEmailPayload(BaseModel):
language: str | None = None
class OwnerTransferCheckPayload(BaseModel):
code: str
token: str
class OwnerTransferPayload(BaseModel):
token: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(MemberInvitePayload)
reg(MemberRoleUpdatePayload)
reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
@console_ns.route("/workspaces/current/members") @console_ns.route("/workspaces/current/members")
class MemberListApi(Resource): class MemberListApi(Resource):
@ -85,22 +48,29 @@ class MemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
parser_invite = (
reqparse.RequestParser()
.add_argument("emails", type=list, required=True, location="json")
.add_argument("role", type=str, required=True, default="admin", location="json")
.add_argument("language", type=str, required=False, location="json")
)
@console_ns.route("/workspaces/current/members/invite-email") @console_ns.route("/workspaces/current/members/invite-email")
class MemberInviteEmailApi(Resource): class MemberInviteEmailApi(Resource):
"""Invite a new member by email.""" """Invite a new member by email."""
@console_ns.expect(console_ns.models[MemberInvitePayload.__name__]) @api.expect(parser_invite)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("members") @cloud_edition_billing_resource_check("members")
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_invite.parse_args()
args = MemberInvitePayload.model_validate(payload)
invitee_emails = args.emails invitee_emails = args["emails"]
invitee_role = args.role invitee_role = args["role"]
interface_language = args.language interface_language = args["language"]
if not TenantAccountRole.is_non_owner_role(invitee_role): if not TenantAccountRole.is_non_owner_role(invitee_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
@ -176,18 +146,20 @@ class MemberCancelInviteApi(Resource):
}, 200 }, 200
parser_update = reqparse.RequestParser().add_argument("role", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role") @console_ns.route("/workspaces/current/members/<uuid:member_id>/update-role")
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):
"""Update member role.""" """Update member role."""
@console_ns.expect(console_ns.models[MemberRoleUpdatePayload.__name__]) @api.expect(parser_update)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def put(self, member_id): def put(self, member_id):
payload = console_ns.payload or {} args = parser_update.parse_args()
args = MemberRoleUpdatePayload.model_validate(payload) new_role = args["role"]
new_role = args.role
if not TenantAccountRole.is_valid_role(new_role): if not TenantAccountRole.is_valid_role(new_role):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
@ -225,18 +197,20 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200 return {"result": "success", "accounts": members}, 200
parser_send = reqparse.RequestParser().add_argument("language", type=str, required=False, location="json")
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email") @console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")
class SendOwnerTransferEmailApi(Resource): class SendOwnerTransferEmailApi(Resource):
"""Send owner transfer email.""" """Send owner transfer email."""
@console_ns.expect(console_ns.models[OwnerTransferEmailPayload.__name__]) @api.expect(parser_send)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_send.parse_args()
args = OwnerTransferEmailPayload.model_validate(payload)
ip_address = extract_remote_ip(request) ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address): if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError() raise EmailSendIpLimitError()
@ -247,7 +221,7 @@ class SendOwnerTransferEmailApi(Resource):
if not TenantService.is_owner(current_user, current_user.current_tenant): if not TenantService.is_owner(current_user, current_user.current_tenant):
raise NotOwnerError() raise NotOwnerError()
if args.language is not None and args.language == "zh-Hans": if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans" language = "zh-Hans"
else: else:
language = "en-US" language = "en-US"
@ -264,16 +238,22 @@ class SendOwnerTransferEmailApi(Resource):
return {"result": "success", "data": token} return {"result": "success", "data": token}
parser_owner = (
reqparse.RequestParser()
.add_argument("code", type=str, required=True, location="json")
.add_argument("token", type=str, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/members/owner-transfer-check") @console_ns.route("/workspaces/current/members/owner-transfer-check")
class OwnerTransferCheckApi(Resource): class OwnerTransferCheckApi(Resource):
@console_ns.expect(console_ns.models[OwnerTransferCheckPayload.__name__]) @api.expect(parser_owner)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self): def post(self):
payload = console_ns.payload or {} args = parser_owner.parse_args()
args = OwnerTransferCheckPayload.model_validate(payload)
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
if not current_user.current_tenant: if not current_user.current_tenant:
@ -287,37 +267,41 @@ class OwnerTransferCheckApi(Resource):
if is_owner_transfer_error_rate_limit: if is_owner_transfer_error_rate_limit:
raise OwnerTransferLimitError() raise OwnerTransferLimitError()
token_data = AccountService.get_owner_transfer_data(args.token) token_data = AccountService.get_owner_transfer_data(args["token"])
if token_data is None: if token_data is None:
raise InvalidTokenError() raise InvalidTokenError()
if user_email != token_data.get("email"): if user_email != token_data.get("email"):
raise InvalidEmailError() raise InvalidEmailError()
if args.code != token_data.get("code"): if args["code"] != token_data.get("code"):
AccountService.add_owner_transfer_error_rate_limit(user_email) AccountService.add_owner_transfer_error_rate_limit(user_email)
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token # Verified, revoke the first token
AccountService.revoke_owner_transfer_token(args.token) AccountService.revoke_owner_transfer_token(args["token"])
# Refresh token data by generating a new token # Refresh token data by generating a new token
_, new_token = AccountService.generate_owner_transfer_token(user_email, code=args.code, additional_data={}) _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
AccountService.reset_owner_transfer_error_rate_limit(user_email) AccountService.reset_owner_transfer_error_rate_limit(user_email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
parser_owner_transfer = reqparse.RequestParser().add_argument(
"token", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer") @console_ns.route("/workspaces/current/members/<uuid:member_id>/owner-transfer")
class OwnerTransfer(Resource): class OwnerTransfer(Resource):
@console_ns.expect(console_ns.models[OwnerTransferPayload.__name__]) @api.expect(parser_owner_transfer)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@is_allow_transfer_owner @is_allow_transfer_owner
def post(self, member_id): def post(self, member_id):
payload = console_ns.payload or {} args = parser_owner_transfer.parse_args()
args = OwnerTransferPayload.model_validate(payload)
# check if the current user is the owner of the workspace # check if the current user is the owner of the workspace
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
@ -329,14 +313,14 @@ class OwnerTransfer(Resource):
if current_user.id == str(member_id): if current_user.id == str(member_id):
raise CannotTransferOwnerToSelfError() raise CannotTransferOwnerToSelfError()
transfer_token_data = AccountService.get_owner_transfer_data(args.token) transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
if not transfer_token_data: if not transfer_token_data:
raise InvalidTokenError() raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email: if transfer_token_data.get("email") != current_user.email:
raise InvalidEmailError() raise InvalidEmailError()
AccountService.revoke_owner_transfer_token(args.token) AccountService.revoke_owner_transfer_token(args["token"])
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if not member:

View File

@ -1,97 +1,32 @@
import io import io
from typing import Any, Literal
from flask import request, send_file from flask import send_file
from flask_restx import Resource from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import Forbidden
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value from libs.helper import StrLen, uuid_value
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" parser_model = reqparse.RequestParser().add_argument(
"model_type",
type=str,
class ParserModelList(BaseModel): required=False,
model_type: ModelType | None = None nullable=True,
choices=[mt.value for mt in ModelType],
location="args",
class ParserCredentialId(BaseModel): )
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_optional_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialCreate(BaseModel):
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
class ParserCredentialUpdate(BaseModel):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialDelete(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialSwitch(BaseModel):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_switch_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserCredentialValidate(BaseModel):
credentials: dict[str, Any]
class ParserPreferredProviderType(BaseModel):
preferred_provider_type: Literal["system", "custom"]
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ParserModelList)
reg(ParserCredentialId)
reg(ParserCredentialCreate)
reg(ParserCredentialUpdate)
reg(ParserCredentialDelete)
reg(ParserCredentialSwitch)
reg(ParserCredentialValidate)
reg(ParserPreferredProviderType)
@console_ns.route("/workspaces/current/model-providers") @console_ns.route("/workspaces/current/model-providers")
class ModelProviderListApi(Resource): class ModelProviderListApi(Resource):
@console_ns.expect(console_ns.models[ParserModelList.__name__]) @api.expect(parser_model)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -99,18 +34,38 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id tenant_id = current_tenant_id
payload = request.args.to_dict(flat=True) # type: ignore args = parser_model.parse_args()
args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.model_type) provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
return jsonable_encoder({"data": provider_list}) return jsonable_encoder({"data": provider_list})
parser_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=False, nullable=True, location="args"
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = reqparse.RequestParser().add_argument(
"credential_id", type=uuid_value, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials")
class ModelProviderCredentialApi(Resource): class ModelProviderCredentialApi(Resource):
@console_ns.expect(console_ns.models[ParserCredentialId.__name__]) @api.expect(parser_cred)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -118,25 +73,25 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential # if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True) # type: ignore args = parser_cred.parse_args()
args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential( credentials = model_provider_service.get_provider_credential(
tenant_id=tenant_id, provider=provider, credential_id=args.credential_id tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
) )
return {"credentials": credentials} return {"credentials": credentials}
@console_ns.expect(console_ns.models[ParserCredentialCreate.__name__]) @api.expect(parser_post_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} if not current_user.is_admin_or_owner:
args = ParserCredentialCreate.model_validate(payload) raise Forbidden()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -144,24 +99,24 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.create_provider_credential( model_provider_service.create_provider_credential(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args.credentials, credentials=args["credentials"],
credential_name=args.name, credential_name=args["name"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
return {"result": "success"}, 201 return {"result": "success"}, 201
@console_ns.expect(console_ns.models[ParserCredentialUpdate.__name__]) @api.expect(parser_put_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
payload = console_ns.payload or {} args = parser_put_cred.parse_args()
args = ParserCredentialUpdate.model_validate(payload)
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -169,64 +124,74 @@ class ModelProviderCredentialApi(Resource):
model_provider_service.update_provider_credential( model_provider_service.update_provider_credential(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credentials=args.credentials, credentials=args["credentials"],
credential_id=args.credential_id, credential_id=args["credential_id"],
credential_name=args.name, credential_name=args["name"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
return {"result": "success"} return {"result": "success"}
@console_ns.expect(console_ns.models[ParserCredentialDelete.__name__]) @api.expect(parser_delete_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} if not current_user.is_admin_or_owner:
args = ParserCredentialDelete.model_validate(payload) raise Forbidden()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential( model_provider_service.remove_provider_credential(
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id tenant_id=current_tenant_id, provider=provider, credential_id=args["credential_id"]
) )
return {"result": "success"}, 204 return {"result": "success"}, 204
parser_switch = reqparse.RequestParser().add_argument(
"credential_id", type=str, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
class ModelProviderCredentialSwitchApi(Resource): class ModelProviderCredentialSwitchApi(Resource):
@console_ns.expect(console_ns.models[ParserCredentialSwitch.__name__]) @api.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} if not current_user.is_admin_or_owner:
args = ParserCredentialSwitch.model_validate(payload) raise Forbidden()
args = parser_switch.parse_args()
service = ModelProviderService() service = ModelProviderService()
service.switch_active_provider_credential( service.switch_active_provider_credential(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
credential_id=args.credential_id, credential_id=args["credential_id"],
) )
return {"result": "success"} return {"result": "success"}
parser_validate = reqparse.RequestParser().add_argument(
"credentials", type=dict, required=True, nullable=False, location="json"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate") @console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/validate")
class ModelProviderValidateApi(Resource): class ModelProviderValidateApi(Resource):
@console_ns.expect(console_ns.models[ParserCredentialValidate.__name__]) @api.expect(parser_validate)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_validate.parse_args()
args = ParserCredentialValidate.model_validate(payload)
tenant_id = current_tenant_id tenant_id = current_tenant_id
@ -237,7 +202,7 @@ class ModelProviderValidateApi(Resource):
try: try:
model_provider_service.validate_provider_credentials( model_provider_service.validate_provider_credentials(
tenant_id=tenant_id, provider=provider, credentials=args.credentials tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
result = False result = False
@ -270,24 +235,34 @@ class ModelProviderIconApi(Resource):
return send_file(io.BytesIO(icon), mimetype=mimetype) return send_file(io.BytesIO(icon), mimetype=mimetype)
parser_preferred = reqparse.RequestParser().add_argument(
"preferred_provider_type",
type=str,
required=True,
nullable=False,
choices=["system", "custom"],
location="json",
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type") @console_ns.route("/workspaces/current/model-providers/<path:provider>/preferred-provider-type")
class PreferredProviderTypeUpdateApi(Resource): class PreferredProviderTypeUpdateApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferredProviderType.__name__]) @api.expect(parser_preferred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_tenant_id tenant_id = current_tenant_id
payload = console_ns.payload or {} args = parser_preferred.parse_args()
args = ParserPreferredProviderType.model_validate(payload)
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.switch_preferred_provider( model_provider_service.switch_preferred_provider(
tenant_id=tenant_id, provider=provider, preferred_provider_type=args.preferred_provider_type tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"]
) )
return {"result": "success"} return {"result": "success"}

View File

@ -1,176 +1,122 @@
import logging import logging
from typing import Any, cast
from flask import request from flask_restx import Resource, reqparse
from flask_restx import Resource from werkzeug.exceptions import Forbidden
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import uuid_value from libs.helper import StrLen, uuid_value
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
from services.model_load_balancing_service import ModelLoadBalancingService from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService from services.model_provider_service import ModelProviderService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ParserGetDefault(BaseModel): parser_get_default = reqparse.RequestParser().add_argument(
model_type: ModelType "model_type",
type=str,
required=True,
class ParserPostDefault(BaseModel): nullable=False,
class Inner(BaseModel): choices=[mt.value for mt in ModelType],
model_type: ModelType location="args",
model: str | None = None )
provider: str | None = None parser_post_default = reqparse.RequestParser().add_argument(
"model_settings", type=list, required=True, nullable=False, location="json"
model_settings: list[Inner] )
class ParserDeleteModels(BaseModel):
model: str
model_type: ModelType
class LoadBalancingPayload(BaseModel):
configs: list[dict[str, Any]] | None = None
enabled: bool | None = None
class ParserPostModels(BaseModel):
model: str
model_type: ModelType
load_balancing: LoadBalancingPayload | None = None
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserGetCredentials(BaseModel):
model: str
model_type: ModelType
config_from: str | None = None
credential_id: str | None = None
@field_validator("credential_id")
@classmethod
def validate_get_credential_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class ParserCredentialBase(BaseModel):
model: str
model_type: ModelType
class ParserCreateCredential(ParserCredentialBase):
name: str | None = Field(default=None, max_length=30)
credentials: dict[str, Any]
class ParserUpdateCredential(ParserCredentialBase):
credential_id: str
credentials: dict[str, Any]
name: str | None = Field(default=None, max_length=30)
@field_validator("credential_id")
@classmethod
def validate_update_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserDeleteCredential(ParserCredentialBase):
credential_id: str
@field_validator("credential_id")
@classmethod
def validate_delete_credential_id(cls, value: str) -> str:
return uuid_value(value)
class ParserParameter(BaseModel):
model: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ParserGetDefault)
reg(ParserPostDefault)
reg(ParserDeleteModels)
reg(ParserPostModels)
reg(ParserGetCredentials)
reg(ParserCreateCredential)
reg(ParserUpdateCredential)
reg(ParserDeleteCredential)
reg(ParserParameter)
@console_ns.route("/workspaces/current/default-model") @console_ns.route("/workspaces/current/default-model")
class DefaultModelApi(Resource): class DefaultModelApi(Resource):
@console_ns.expect(console_ns.models[ParserGetDefault.__name__]) @api.expect(parser_get_default)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserGetDefault.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_get_default.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
default_model_entity = model_provider_service.get_default_model_of_model_type( default_model_entity = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id, model_type=args.model_type tenant_id=tenant_id, model_type=args["model_type"]
) )
return jsonable_encoder({"data": default_model_entity}) return jsonable_encoder({"data": default_model_entity})
@console_ns.expect(console_ns.models[ParserPostDefault.__name__]) @api.expect(parser_post_default)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
args = ParserPostDefault.model_validate(console_ns.payload) if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_default.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_settings = args.model_settings model_settings = args["model_settings"]
for model_setting in model_settings: for model_setting in model_settings:
if model_setting.provider is None: if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
raise ValueError("invalid model type")
if "provider" not in model_setting:
continue continue
if "model" not in model_setting:
raise ValueError("invalid model")
try: try:
model_provider_service.update_default_model_of_model_type( model_provider_service.update_default_model_of_model_type(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=model_setting.model_type, model_type=model_setting["model_type"],
provider=model_setting.provider, provider=model_setting["provider"],
model=cast(str, model_setting.model), model=model_setting["model"],
) )
except Exception as ex: except Exception as ex:
logger.exception( logger.exception(
"Failed to update default model, model type: %s, model: %s", "Failed to update default model, model type: %s, model: %s",
model_setting.model_type, model_setting["model_type"],
model_setting.model, model_setting.get("model"),
) )
raise ex raise ex
return {"result": "success"} return {"result": "success"}
parser_post_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
.add_argument("config_from", type=str, required=False, nullable=True, location="json")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
)
parser_delete_models = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models")
class ModelProviderModelApi(Resource): class ModelProviderModelApi(Resource):
@setup_required @setup_required
@ -184,107 +130,171 @@ class ModelProviderModelApi(Resource):
return jsonable_encoder({"data": models}) return jsonable_encoder({"data": models})
@console_ns.expect(console_ns.models[ParserPostModels.__name__]) @api.expect(parser_post_models)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
# To save the model's load balance configs # To save the model's load balance configs
_, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
args = ParserPostModels.model_validate(console_ns.payload)
if args.config_from == "custom-model": if not current_user.is_admin_or_owner:
if not args.credential_id: raise Forbidden()
args = parser_post_models.parse_args()
if args.get("config_from", "") == "custom-model":
if not args.get("credential_id"):
raise ValueError("credential_id is required when configuring a custom-model") raise ValueError("credential_id is required when configuring a custom-model")
service = ModelProviderService() service = ModelProviderService()
service.switch_active_custom_model_credential( service.switch_active_custom_model_credential(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model_type=args.model_type, model_type=args["model_type"],
model=args.model, model=args["model"],
credential_id=args.credential_id, credential_id=args["credential_id"],
) )
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
if args.load_balancing and args.load_balancing.configs: if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]:
# save load balancing configs # save load balancing configs
model_load_balancing_service.update_load_balancing_configs( model_load_balancing_service.update_load_balancing_configs(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args.model, model=args["model"],
model_type=args.model_type, model_type=args["model_type"],
configs=args.load_balancing.configs, configs=args["load_balancing"]["configs"],
config_from=args.config_from or "", config_from=args.get("config_from", ""),
) )
if args.load_balancing.enabled: if args.get("load_balancing", {}).get("enabled"):
model_load_balancing_service.enable_model_load_balancing( model_load_balancing_service.enable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
) )
else: else:
model_load_balancing_service.disable_model_load_balancing( model_load_balancing_service.disable_model_load_balancing(
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__], validate=True) @api.expect(parser_delete_models)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
_, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload) if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_delete_models.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_model( model_provider_service.remove_model(
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
) )
return {"result": "success"}, 204 return {"result": "success"}, 204
parser_get_credentials = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="args")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="args",
)
.add_argument("config_from", type=str, required=False, nullable=True, location="args")
.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
)
parser_post_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
)
parser_put_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
)
parser_delete_cred = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
class ModelProviderModelCredentialApi(Resource): class ModelProviderModelCredentialApi(Resource):
@console_ns.expect(console_ns.models[ParserGetCredentials.__name__]) @api.expect(parser_get_credentials)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserGetCredentials.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_get_credentials.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
current_credential = model_provider_service.get_model_credential( current_credential = model_provider_service.get_model_credential(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model_type=args.model_type, model_type=args["model_type"],
model=args.model, model=args["model"],
credential_id=args.credential_id, credential_id=args.get("credential_id"),
) )
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args.model, model=args["model"],
model_type=args.model_type, model_type=args["model_type"],
config_from=args.config_from or "", config_from=args.get("config_from", ""),
) )
if args.config_from == "predefined-model": if args.get("config_from", "") == "predefined-model":
available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( available_credentials = model_provider_service.provider_manager.get_provider_available_credentials(
tenant_id=tenant_id, provider_name=provider tenant_id=tenant_id, provider_name=provider
) )
else: else:
model_type = args.model_type model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials(
tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args.model tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
) )
return jsonable_encoder( return jsonable_encoder(
@ -301,15 +311,17 @@ class ModelProviderModelCredentialApi(Resource):
} }
) )
@console_ns.expect(console_ns.models[ParserCreateCredential.__name__]) @api.expect(parser_post_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
args = ParserCreateCredential.model_validate(console_ns.payload) if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_post_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -317,30 +329,33 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.create_model_credential( model_provider_service.create_model_credential(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args.model, model=args["model"],
model_type=args.model_type, model_type=args["model_type"],
credentials=args.credentials, credentials=args["credentials"],
credential_name=args.name, credential_name=args["name"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
logger.exception( logger.exception(
"Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
tenant_id, tenant_id,
args.model, args.get("model"),
args.model_type, args.get("model_type"),
) )
raise ValueError(str(ex)) raise ValueError(str(ex))
return {"result": "success"}, 201 return {"result": "success"}, 201
@console_ns.expect(console_ns.models[ParserUpdateCredential.__name__]) @api.expect(parser_put_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def put(self, provider: str): def put(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
args = ParserUpdateCredential.model_validate(console_ns.payload)
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_put_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -348,87 +363,109 @@ class ModelProviderModelCredentialApi(Resource):
model_provider_service.update_model_credential( model_provider_service.update_model_credential(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args.model_type, model_type=args["model_type"],
model=args.model, model=args["model"],
credentials=args.credentials, credentials=args["credentials"],
credential_id=args.credential_id, credential_id=args["credential_id"],
credential_name=args.name, credential_name=args["name"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
return {"result": "success"} return {"result": "success"}
@console_ns.expect(console_ns.models[ParserDeleteCredential.__name__]) @api.expect(parser_delete_cred)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str): def delete(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
args = ParserDeleteCredential.model_validate(console_ns.payload)
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_delete_cred.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.remove_model_credential( model_provider_service.remove_model_credential(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args.model_type, model_type=args["model_type"],
model=args.model, model=args["model"],
credential_id=args.credential_id, credential_id=args["credential_id"],
) )
return {"result": "success"}, 204 return {"result": "success"}, 204
class ParserSwitch(BaseModel): parser_switch = (
model: str reqparse.RequestParser()
model_type: ModelType .add_argument("model", type=str, required=True, nullable=False, location="json")
credential_id: str .add_argument(
"model_type",
type=str,
console_ns.schema_model( required=True,
ParserSwitch.__name__, ParserSwitch.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
) )
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
class ModelProviderModelCredentialSwitchApi(Resource): class ModelProviderModelCredentialSwitchApi(Resource):
@console_ns.expect(console_ns.models[ParserSwitch.__name__]) @api.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, current_tenant_id = current_account_with_tenant() current_user, current_tenant_id = current_account_with_tenant()
args = ParserSwitch.model_validate(console_ns.payload)
if not current_user.is_admin_or_owner:
raise Forbidden()
args = parser_switch.parse_args()
service = ModelProviderService() service = ModelProviderService()
service.add_model_credential_to_model_list( service.add_model_credential_to_model_list(
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
provider=provider, provider=provider,
model_type=args.model_type, model_type=args["model_type"],
model=args.model, model=args["model"],
credential_id=args.credential_id, credential_id=args["credential_id"],
) )
return {"result": "success"} return {"result": "success"}
parser_model_enable_disable = (
reqparse.RequestParser()
.add_argument("model", type=str, required=True, nullable=False, location="json")
.add_argument(
"model_type",
type=str,
required=True,
nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
)
@console_ns.route( @console_ns.route(
"/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable" "/workspaces/current/model-providers/<path:provider>/models/enable", endpoint="model-provider-model-enable"
) )
class ModelProviderModelEnableApi(Resource): class ModelProviderModelEnableApi(Resource):
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @api.expect(parser_model_enable_disable)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload) args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.enable_model( model_provider_service.enable_model(
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
) )
return {"result": "success"} return {"result": "success"}
@ -438,43 +475,48 @@ class ModelProviderModelEnableApi(Resource):
"/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable" "/workspaces/current/model-providers/<path:provider>/models/disable", endpoint="model-provider-model-disable"
) )
class ModelProviderModelDisableApi(Resource): class ModelProviderModelDisableApi(Resource):
@console_ns.expect(console_ns.models[ParserDeleteModels.__name__]) @api.expect(parser_model_enable_disable)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str): def patch(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserDeleteModels.model_validate(console_ns.payload) args = parser_model_enable_disable.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
model_provider_service.disable_model( model_provider_service.disable_model(
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
) )
return {"result": "success"} return {"result": "success"}
class ParserValidate(BaseModel): parser_validate = (
model: str reqparse.RequestParser()
model_type: ModelType .add_argument("model", type=str, required=True, nullable=False, location="json")
credentials: dict .add_argument(
"model_type",
type=str,
console_ns.schema_model( required=True,
ParserValidate.__name__, ParserValidate.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) nullable=False,
choices=[mt.value for mt in ModelType],
location="json",
)
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
) )
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/validate")
class ModelProviderModelValidateApi(Resource): class ModelProviderModelValidateApi(Resource):
@console_ns.expect(console_ns.models[ParserValidate.__name__]) @api.expect(parser_validate)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserValidate.model_validate(console_ns.payload)
args = parser_validate.parse_args()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
@ -485,9 +527,9 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service.validate_model_credentials( model_provider_service.validate_model_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
model=args.model, model=args["model"],
model_type=args.model_type, model_type=args["model_type"],
credentials=args.credentials, credentials=args["credentials"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
result = False result = False
@ -501,19 +543,24 @@ class ModelProviderModelValidateApi(Resource):
return response return response
parser_parameter = reqparse.RequestParser().add_argument(
"model", type=str, required=True, nullable=False, location="args"
)
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules") @console_ns.route("/workspaces/current/model-providers/<path:provider>/models/parameter-rules")
class ModelProviderModelParameterRuleApi(Resource): class ModelProviderModelParameterRuleApi(Resource):
@console_ns.expect(console_ns.models[ParserParameter.__name__]) @api.expect(parser_parameter)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider: str):
args = ParserParameter.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_parameter.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
parameter_rules = model_provider_service.get_model_parameter_rules( parameter_rules = model_provider_service.get_model_parameter_rules(
tenant_id=tenant_id, provider=provider, model=args.model tenant_id=tenant_id, provider=provider, model=args["model"]
) )
return jsonable_encoder({"data": parameter_rules}) return jsonable_encoder({"data": parameter_rules})

View File

@ -1,15 +1,13 @@
import io import io
from typing import Literal
from flask import request, send_file from flask import request, send_file
from flask_restx import Resource from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.workspace import plugin_permission_required from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import current_account_with_tenant, login_required from libs.login import current_account_with_tenant, login_required
@ -19,12 +17,6 @@ from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/workspaces/current/plugin/debugging-key") @console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource): class PluginDebuggingKeyApi(Resource):
@ -45,194 +37,88 @@ class PluginDebuggingKeyApi(Resource):
raise ValueError(e) raise ValueError(e)
class ParserList(BaseModel): parser_list = (
page: int = Field(default=1) reqparse.RequestParser()
page_size: int = Field(default=256) .add_argument("page", type=int, required=False, location="args", default=1)
.add_argument("page_size", type=int, required=False, location="args", default=256)
)
reg(ParserList)
@console_ns.route("/workspaces/current/plugin/list") @console_ns.route("/workspaces/current/plugin/list")
class PluginListApi(Resource): class PluginListApi(Resource):
@console_ns.expect(console_ns.models[ParserList.__name__]) @api.expect(parser_list)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserList.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_list.parse_args()
try: try:
plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total})
class ParserLatest(BaseModel): parser_latest = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
plugin_ids: list[str]
class ParserIcon(BaseModel):
tenant_id: str
filename: str
class ParserAsset(BaseModel):
plugin_unique_identifier: str
file_name: str
class ParserGithubUpload(BaseModel):
repo: str
version: str
package: str
class ParserPluginIdentifiers(BaseModel):
plugin_unique_identifiers: list[str]
class ParserGithubInstall(BaseModel):
plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserPluginIdentifierQuery(BaseModel):
plugin_unique_identifier: str
class ParserTasks(BaseModel):
page: int
page_size: int
class ParserMarketplaceUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
class ParserGithubUpgrade(BaseModel):
original_plugin_unique_identifier: str
new_plugin_unique_identifier: str
repo: str
version: str
package: str
class ParserUninstall(BaseModel):
plugin_installation_id: str
class ParserPermissionChange(BaseModel):
install_permission: TenantPluginPermission.InstallPermission
debug_permission: TenantPluginPermission.DebugPermission
class ParserDynamicOptions(BaseModel):
plugin_id: str
provider: str
action: str
parameter: str
credential_id: str | None = None
provider_type: Literal["tool", "trigger"]
class PluginPermissionSettingsPayload(BaseModel):
install_permission: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE
debug_permission: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE
class PluginAutoUpgradeSettingsPayload(BaseModel):
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting = (
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY
)
upgrade_time_of_day: int = 0
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exclude_plugins: list[str] = Field(default_factory=list)
include_plugins: list[str] = Field(default_factory=list)
class ParserPreferencesChange(BaseModel):
permission: PluginPermissionSettingsPayload
auto_upgrade: PluginAutoUpgradeSettingsPayload
class ParserExcludePlugin(BaseModel):
plugin_id: str
class ParserReadme(BaseModel):
plugin_unique_identifier: str
language: str = Field(default="en-US")
reg(ParserLatest)
reg(ParserIcon)
reg(ParserAsset)
reg(ParserGithubUpload)
reg(ParserPluginIdentifiers)
reg(ParserGithubInstall)
reg(ParserPluginIdentifierQuery)
reg(ParserTasks)
reg(ParserMarketplaceUpgrade)
reg(ParserGithubUpgrade)
reg(ParserUninstall)
reg(ParserPermissionChange)
reg(ParserDynamicOptions)
reg(ParserPreferencesChange)
reg(ParserExcludePlugin)
reg(ParserReadme)
@console_ns.route("/workspaces/current/plugin/list/latest-versions") @console_ns.route("/workspaces/current/plugin/list/latest-versions")
class PluginListLatestVersionsApi(Resource): class PluginListLatestVersionsApi(Resource):
@console_ns.expect(console_ns.models[ParserLatest.__name__]) @api.expect(parser_latest)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
args = ParserLatest.model_validate(console_ns.payload) args = parser_latest.parse_args()
try: try:
versions = PluginService.list_latest_versions(args.plugin_ids) versions = PluginService.list_latest_versions(args["plugin_ids"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
return jsonable_encoder({"versions": versions}) return jsonable_encoder({"versions": versions})
parser_ids = reqparse.RequestParser().add_argument("plugin_ids", type=list, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/list/installations/ids") @console_ns.route("/workspaces/current/plugin/list/installations/ids")
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@console_ns.expect(console_ns.models[ParserLatest.__name__]) @api.expect(parser_ids)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserLatest.model_validate(console_ns.payload) args = parser_ids.parse_args()
try: try:
plugins = PluginService.list_installations_from_ids(tenant_id, args.plugin_ids) plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
return jsonable_encoder({"plugins": plugins}) return jsonable_encoder({"plugins": plugins})
parser_icon = (
reqparse.RequestParser()
.add_argument("tenant_id", type=str, required=True, location="args")
.add_argument("filename", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/icon") @console_ns.route("/workspaces/current/plugin/icon")
class PluginIconApi(Resource): class PluginIconApi(Resource):
@console_ns.expect(console_ns.models[ParserIcon.__name__]) @api.expect(parser_icon)
@setup_required @setup_required
def get(self): def get(self):
args = ParserIcon.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_icon.parse_args()
try: try:
icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -242,16 +128,18 @@ class PluginIconApi(Resource):
@console_ns.route("/workspaces/current/plugin/asset") @console_ns.route("/workspaces/current/plugin/asset")
class PluginAssetApi(Resource): class PluginAssetApi(Resource):
@console_ns.expect(console_ns.models[ParserAsset.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
args = ParserAsset.model_validate(request.args.to_dict(flat=True)) # type: ignore req = reqparse.RequestParser()
req.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
req.add_argument("file_name", type=str, required=True, location="args")
args = req.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
try: try:
binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) binary = PluginService.extract_asset(tenant_id, args["plugin_unique_identifier"], args["file_name"])
return send_file(io.BytesIO(binary), mimetype="application/octet-stream") return send_file(io.BytesIO(binary), mimetype="application/octet-stream")
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -281,9 +169,17 @@ class PluginUploadFromPkgApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
parser_github = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upload/github") @console_ns.route("/workspaces/current/plugin/upload/github")
class PluginUploadFromGithubApi(Resource): class PluginUploadFromGithubApi(Resource):
@console_ns.expect(console_ns.models[ParserGithubUpload.__name__]) @api.expect(parser_github)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -291,10 +187,10 @@ class PluginUploadFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserGithubUpload.model_validate(console_ns.payload) args = parser_github.parse_args()
try: try:
response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package) response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -325,28 +221,47 @@ class PluginUploadFromBundleApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
parser_pkg = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/pkg") @console_ns.route("/workspaces/current/plugin/install/pkg")
class PluginInstallFromPkgApi(Resource): class PluginInstallFromPkgApi(Resource):
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__]) @api.expect(parser_pkg)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserPluginIdentifiers.model_validate(console_ns.payload) args = parser_pkg.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try: try:
response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers) response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
return jsonable_encoder(response) return jsonable_encoder(response)
parser_githubapi = (
reqparse.RequestParser()
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/install/github") @console_ns.route("/workspaces/current/plugin/install/github")
class PluginInstallFromGithubApi(Resource): class PluginInstallFromGithubApi(Resource):
@console_ns.expect(console_ns.models[ParserGithubInstall.__name__]) @api.expect(parser_githubapi)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -354,15 +269,15 @@ class PluginInstallFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserGithubInstall.model_validate(console_ns.payload) args = parser_githubapi.parse_args()
try: try:
response = PluginService.install_from_github( response = PluginService.install_from_github(
tenant_id, tenant_id,
args.plugin_unique_identifier, args["plugin_unique_identifier"],
args.repo, args["repo"],
args.version, args["version"],
args.package, args["package"],
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -370,9 +285,14 @@ class PluginInstallFromGithubApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
parser_marketplace = reqparse.RequestParser().add_argument(
"plugin_unique_identifiers", type=list, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/install/marketplace") @console_ns.route("/workspaces/current/plugin/install/marketplace")
class PluginInstallFromMarketplaceApi(Resource): class PluginInstallFromMarketplaceApi(Resource):
@console_ns.expect(console_ns.models[ParserPluginIdentifiers.__name__]) @api.expect(parser_marketplace)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -380,33 +300,43 @@ class PluginInstallFromMarketplaceApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserPluginIdentifiers.model_validate(console_ns.payload) args = parser_marketplace.parse_args()
# check if all plugin_unique_identifiers are valid string
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
if not isinstance(plugin_unique_identifier, str):
raise ValueError("Invalid plugin unique identifier")
try: try:
response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers) response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
return jsonable_encoder(response) return jsonable_encoder(response)
parser_pkgapi = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/marketplace/pkg") @console_ns.route("/workspaces/current/plugin/marketplace/pkg")
class PluginFetchMarketplacePkgApi(Resource): class PluginFetchMarketplacePkgApi(Resource):
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__]) @api.expect(parser_pkgapi)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_pkgapi.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
{ {
"manifest": PluginService.fetch_marketplace_pkg( "manifest": PluginService.fetch_marketplace_pkg(
tenant_id, tenant_id,
args.plugin_unique_identifier, args["plugin_unique_identifier"],
) )
} }
) )
@ -414,9 +344,14 @@ class PluginFetchMarketplacePkgApi(Resource):
raise ValueError(e) raise ValueError(e)
parser_fetch = reqparse.RequestParser().add_argument(
"plugin_unique_identifier", type=str, required=True, location="args"
)
@console_ns.route("/workspaces/current/plugin/fetch-manifest") @console_ns.route("/workspaces/current/plugin/fetch-manifest")
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@console_ns.expect(console_ns.models[ParserPluginIdentifierQuery.__name__]) @api.expect(parser_fetch)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -424,19 +359,30 @@ class PluginFetchManifestApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserPluginIdentifierQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_fetch.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
{"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()} {
"manifest": PluginService.fetch_plugin_manifest(
tenant_id, args["plugin_unique_identifier"]
).model_dump()
}
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
parser_tasks = (
reqparse.RequestParser()
.add_argument("page", type=int, required=True, location="args")
.add_argument("page_size", type=int, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/tasks") @console_ns.route("/workspaces/current/plugin/tasks")
class PluginFetchInstallTasksApi(Resource): class PluginFetchInstallTasksApi(Resource):
@console_ns.expect(console_ns.models[ParserTasks.__name__]) @api.expect(parser_tasks)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -444,10 +390,12 @@ class PluginFetchInstallTasksApi(Resource):
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserTasks.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_tasks.parse_args()
try: try:
return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) return jsonable_encoder(
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
)
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -512,9 +460,16 @@ class PluginDeleteInstallTaskItemApi(Resource):
raise ValueError(e) raise ValueError(e)
parser_marketplace_api = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/marketplace") @console_ns.route("/workspaces/current/plugin/upgrade/marketplace")
class PluginUpgradeFromMarketplaceApi(Resource): class PluginUpgradeFromMarketplaceApi(Resource):
@console_ns.expect(console_ns.models[ParserMarketplaceUpgrade.__name__]) @api.expect(parser_marketplace_api)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -522,21 +477,31 @@ class PluginUpgradeFromMarketplaceApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserMarketplaceUpgrade.model_validate(console_ns.payload) args = parser_marketplace_api.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
PluginService.upgrade_plugin_with_marketplace( PluginService.upgrade_plugin_with_marketplace(
tenant_id, args.original_plugin_unique_identifier, args.new_plugin_unique_identifier tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
) )
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
parser_github_post = (
reqparse.RequestParser()
.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
.add_argument("repo", type=str, required=True, location="json")
.add_argument("version", type=str, required=True, location="json")
.add_argument("package", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/upgrade/github") @console_ns.route("/workspaces/current/plugin/upgrade/github")
class PluginUpgradeFromGithubApi(Resource): class PluginUpgradeFromGithubApi(Resource):
@console_ns.expect(console_ns.models[ParserGithubUpgrade.__name__]) @api.expect(parser_github_post)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -544,44 +509,56 @@ class PluginUpgradeFromGithubApi(Resource):
def post(self): def post(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserGithubUpgrade.model_validate(console_ns.payload) args = parser_github_post.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
PluginService.upgrade_plugin_with_github( PluginService.upgrade_plugin_with_github(
tenant_id, tenant_id,
args.original_plugin_unique_identifier, args["original_plugin_unique_identifier"],
args.new_plugin_unique_identifier, args["new_plugin_unique_identifier"],
args.repo, args["repo"],
args.version, args["version"],
args.package, args["package"],
) )
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
parser_uninstall = reqparse.RequestParser().add_argument(
"plugin_installation_id", type=str, required=True, location="json"
)
@console_ns.route("/workspaces/current/plugin/uninstall") @console_ns.route("/workspaces/current/plugin/uninstall")
class PluginUninstallApi(Resource): class PluginUninstallApi(Resource):
@console_ns.expect(console_ns.models[ParserUninstall.__name__]) @api.expect(parser_uninstall)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(install_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
args = ParserUninstall.model_validate(console_ns.payload) args = parser_uninstall.parse_args()
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
try: try:
return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
parser_change_post = (
reqparse.RequestParser()
.add_argument("install_permission", type=str, required=True, location="json")
.add_argument("debug_permission", type=str, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/permission/change") @console_ns.route("/workspaces/current/plugin/permission/change")
class PluginChangePermissionApi(Resource): class PluginChangePermissionApi(Resource):
@console_ns.expect(console_ns.models[ParserPermissionChange.__name__]) @api.expect(parser_change_post)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -591,15 +568,14 @@ class PluginChangePermissionApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
args = ParserPermissionChange.model_validate(console_ns.payload) args = parser_change_post.parse_args()
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
tenant_id = current_tenant_id tenant_id = current_tenant_id
return { return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
"success": PluginPermissionService.change_permission(
tenant_id, args.install_permission, args.debug_permission
)
}
@console_ns.route("/workspaces/current/plugin/permission/fetch") @console_ns.route("/workspaces/current/plugin/permission/fetch")
@ -627,29 +603,43 @@ class PluginFetchPermissionApi(Resource):
) )
parser_dynamic = (
reqparse.RequestParser()
.add_argument("plugin_id", type=str, required=True, location="args")
.add_argument("provider", type=str, required=True, location="args")
.add_argument("action", type=str, required=True, location="args")
.add_argument("parameter", type=str, required=True, location="args")
.add_argument("credential_id", type=str, required=False, location="args")
.add_argument("provider_type", type=str, required=True, location="args")
)
@console_ns.route("/workspaces/current/plugin/parameters/dynamic-options") @console_ns.route("/workspaces/current/plugin/parameters/dynamic-options")
class PluginFetchDynamicSelectOptionsApi(Resource): class PluginFetchDynamicSelectOptionsApi(Resource):
@console_ns.expect(console_ns.models[ParserDynamicOptions.__name__]) @api.expect(parser_dynamic)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
# check if the user is admin or owner
current_user, tenant_id = current_account_with_tenant() current_user, tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
user_id = current_user.id user_id = current_user.id
args = ParserDynamicOptions.model_validate(request.args.to_dict(flat=True)) # type: ignore args = parser_dynamic.parse_args()
try: try:
options = PluginParameterService.get_dynamic_select_options( options = PluginParameterService.get_dynamic_select_options(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=args.plugin_id, plugin_id=args["plugin_id"],
provider=args.provider, provider=args["provider"],
action=args.action, action=args["action"],
parameter=args.parameter, parameter=args["parameter"],
credential_id=args.credential_id, credential_id=args["credential_id"],
provider_type=args.provider_type, provider_type=args["provider_type"],
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
raise ValueError(e) raise ValueError(e)
@ -657,9 +647,16 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options}) return jsonable_encoder({"options": options})
parser_change = (
reqparse.RequestParser()
.add_argument("permission", type=dict, required=True, location="json")
.add_argument("auto_upgrade", type=dict, required=True, location="json")
)
@console_ns.route("/workspaces/current/plugin/preferences/change") @console_ns.route("/workspaces/current/plugin/preferences/change")
class PluginChangePreferencesApi(Resource): class PluginChangePreferencesApi(Resource):
@console_ns.expect(console_ns.models[ParserPreferencesChange.__name__]) @api.expect(parser_change)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -668,20 +665,22 @@ class PluginChangePreferencesApi(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
args = ParserPreferencesChange.model_validate(console_ns.payload) args = parser_change.parse_args()
permission = args.permission permission = args["permission"]
install_permission = permission.install_permission install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = permission.debug_permission debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
auto_upgrade = args.auto_upgrade auto_upgrade = args["auto_upgrade"]
strategy_setting = auto_upgrade.strategy_setting strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
upgrade_time_of_day = auto_upgrade.upgrade_time_of_day auto_upgrade.get("strategy_setting", "fix_only")
upgrade_mode = auto_upgrade.upgrade_mode )
exclude_plugins = auto_upgrade.exclude_plugins upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
include_plugins = auto_upgrade.include_plugins upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])
# set permission # set permission
set_permission_result = PluginPermissionService.change_permission( set_permission_result = PluginPermissionService.change_permission(
@ -746,9 +745,12 @@ class PluginFetchPreferencesApi(Resource):
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict}) return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
parser_exclude = reqparse.RequestParser().add_argument("plugin_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude") @console_ns.route("/workspaces/current/plugin/preferences/autoupgrade/exclude")
class PluginAutoUpgradeExcludePluginApi(Resource): class PluginAutoUpgradeExcludePluginApi(Resource):
@console_ns.expect(console_ns.models[ParserExcludePlugin.__name__]) @api.expect(parser_exclude)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -756,20 +758,26 @@ class PluginAutoUpgradeExcludePluginApi(Resource):
# exclude one single plugin # exclude one single plugin
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserExcludePlugin.model_validate(console_ns.payload) args = parser_exclude.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args.plugin_id)}) return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
@console_ns.route("/workspaces/current/plugin/readme") @console_ns.route("/workspaces/current/plugin/readme")
class PluginReadmeApi(Resource): class PluginReadmeApi(Resource):
@console_ns.expect(console_ns.models[ParserReadme.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self):
_, tenant_id = current_account_with_tenant() _, tenant_id = current_account_with_tenant()
args = ParserReadme.model_validate(request.args.to_dict(flat=True)) # type: ignore parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
parser.add_argument("language", type=str, required=False, location="args")
args = parser.parse_args()
return jsonable_encoder( return jsonable_encoder(
{"readme": PluginService.fetch_plugin_readme(tenant_id, args.plugin_unique_identifier, args.language)} {
"readme": PluginService.fetch_plugin_readme(
tenant_id, args["plugin_unique_identifier"], args.get("language", "en-US")
)
}
) )

View File

@ -10,11 +10,10 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
enterprise_license_required, enterprise_license_required,
is_admin_or_owner_required,
setup_required, setup_required,
) )
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
@ -65,7 +64,7 @@ parser_tool = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-providers") @console_ns.route("/workspaces/current/tool-providers")
class ToolProviderListApi(Resource): class ToolProviderListApi(Resource):
@console_ns.expect(parser_tool) @api.expect(parser_tool)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -113,13 +112,14 @@ parser_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/delete")
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@console_ns.expect(parser_delete) @api.expect(parser_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
_, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_delete.parse_args() args = parser_delete.parse_args()
@ -140,7 +140,7 @@ parser_add = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/add")
class ToolBuiltinProviderAddApi(Resource): class ToolBuiltinProviderAddApi(Resource):
@console_ns.expect(parser_add) @api.expect(parser_add)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -174,13 +174,16 @@ parser_update = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/update")
class ToolBuiltinProviderUpdateApi(Resource): class ToolBuiltinProviderUpdateApi(Resource):
@console_ns.expect(parser_update) @api.expect(parser_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_update.parse_args() args = parser_update.parse_args()
@ -236,14 +239,16 @@ parser_api_add = (
@console_ns.route("/workspaces/current/tool-provider/api/add") @console_ns.route("/workspaces/current/tool-provider/api/add")
class ToolApiProviderAddApi(Resource): class ToolApiProviderAddApi(Resource):
@console_ns.expect(parser_api_add) @api.expect(parser_api_add)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_add.parse_args() args = parser_api_add.parse_args()
@ -267,7 +272,7 @@ parser_remote = reqparse.RequestParser().add_argument("url", type=str, required=
@console_ns.route("/workspaces/current/tool-provider/api/remote") @console_ns.route("/workspaces/current/tool-provider/api/remote")
class ToolApiProviderGetRemoteSchemaApi(Resource): class ToolApiProviderGetRemoteSchemaApi(Resource):
@console_ns.expect(parser_remote) @api.expect(parser_remote)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -292,7 +297,7 @@ parser_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/tools") @console_ns.route("/workspaces/current/tool-provider/api/tools")
class ToolApiProviderListToolsApi(Resource): class ToolApiProviderListToolsApi(Resource):
@console_ns.expect(parser_tools) @api.expect(parser_tools)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -328,14 +333,16 @@ parser_api_update = (
@console_ns.route("/workspaces/current/tool-provider/api/update") @console_ns.route("/workspaces/current/tool-provider/api/update")
class ToolApiProviderUpdateApi(Resource): class ToolApiProviderUpdateApi(Resource):
@console_ns.expect(parser_api_update) @api.expect(parser_api_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_update.parse_args() args = parser_api_update.parse_args()
@ -362,14 +369,16 @@ parser_api_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/delete") @console_ns.route("/workspaces/current/tool-provider/api/delete")
class ToolApiProviderDeleteApi(Resource): class ToolApiProviderDeleteApi(Resource):
@console_ns.expect(parser_api_delete) @api.expect(parser_api_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_api_delete.parse_args() args = parser_api_delete.parse_args()
@ -386,7 +395,7 @@ parser_get = reqparse.RequestParser().add_argument("provider", type=str, require
@console_ns.route("/workspaces/current/tool-provider/api/get") @console_ns.route("/workspaces/current/tool-provider/api/get")
class ToolApiProviderGetApi(Resource): class ToolApiProviderGetApi(Resource):
@console_ns.expect(parser_get) @api.expect(parser_get)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -426,7 +435,7 @@ parser_schema = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/api/schema") @console_ns.route("/workspaces/current/tool-provider/api/schema")
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@console_ns.expect(parser_schema) @api.expect(parser_schema)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -451,7 +460,7 @@ parser_pre = (
@console_ns.route("/workspaces/current/tool-provider/api/test/pre") @console_ns.route("/workspaces/current/tool-provider/api/test/pre")
class ToolApiProviderPreviousTestApi(Resource): class ToolApiProviderPreviousTestApi(Resource):
@console_ns.expect(parser_pre) @api.expect(parser_pre)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -484,14 +493,16 @@ parser_create = (
@console_ns.route("/workspaces/current/tool-provider/workflow/create") @console_ns.route("/workspaces/current/tool-provider/workflow/create")
class ToolWorkflowProviderCreateApi(Resource): class ToolWorkflowProviderCreateApi(Resource):
@console_ns.expect(parser_create) @api.expect(parser_create)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_create.parse_args() args = parser_create.parse_args()
@ -525,13 +536,16 @@ parser_workflow_update = (
@console_ns.route("/workspaces/current/tool-provider/workflow/update") @console_ns.route("/workspaces/current/tool-provider/workflow/update")
class ToolWorkflowProviderUpdateApi(Resource): class ToolWorkflowProviderUpdateApi(Resource):
@console_ns.expect(parser_workflow_update) @api.expect(parser_workflow_update)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_workflow_update.parse_args() args = parser_workflow_update.parse_args()
@ -560,14 +574,16 @@ parser_workflow_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/delete") @console_ns.route("/workspaces/current/tool-provider/workflow/delete")
class ToolWorkflowProviderDeleteApi(Resource): class ToolWorkflowProviderDeleteApi(Resource):
@console_ns.expect(parser_workflow_delete) @api.expect(parser_workflow_delete)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id user_id = user.id
args = parser_workflow_delete.parse_args() args = parser_workflow_delete.parse_args()
@ -588,7 +604,7 @@ parser_wf_get = (
@console_ns.route("/workspaces/current/tool-provider/workflow/get") @console_ns.route("/workspaces/current/tool-provider/workflow/get")
class ToolWorkflowProviderGetApi(Resource): class ToolWorkflowProviderGetApi(Resource):
@console_ns.expect(parser_wf_get) @api.expect(parser_wf_get)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -624,7 +640,7 @@ parser_wf_tools = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/workflow/tools") @console_ns.route("/workspaces/current/tool-provider/workflow/tools")
class ToolWorkflowProviderListToolApi(Resource): class ToolWorkflowProviderListToolApi(Resource):
@console_ns.expect(parser_wf_tools) @api.expect(parser_wf_tools)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -718,15 +734,18 @@ class ToolLabelsApi(Resource):
class ToolPluginOAuthApi(Resource): class ToolPluginOAuthApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name provider_name = tool_provider.provider_name
# todo check permission
user, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None: if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
@ -813,7 +832,7 @@ parser_default_cred = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/default-credential")
class ToolBuiltinProviderSetDefaultApi(Resource): class ToolBuiltinProviderSetDefaultApi(Resource):
@console_ns.expect(parser_default_cred) @api.expect(parser_default_cred)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -834,15 +853,17 @@ parser_custom = (
@console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client") @console_ns.route("/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolOAuthCustomClient(Resource): class ToolOAuthCustomClient(Resource):
@console_ns.expect(parser_custom) @api.expect(parser_custom)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider):
args = parser_custom.parse_args() args = parser_custom.parse_args()
_, tenant_id = current_account_with_tenant() user, tenant_id = current_account_with_tenant()
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params( return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -932,7 +953,7 @@ parser_mcp_delete = reqparse.RequestParser().add_argument(
@console_ns.route("/workspaces/current/tool-provider/mcp") @console_ns.route("/workspaces/current/tool-provider/mcp")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@console_ns.expect(parser_mcp) @api.expect(parser_mcp)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -962,7 +983,7 @@ class ToolProviderMCPApi(Resource):
) )
return jsonable_encoder(result) return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put) @api.expect(parser_mcp_put)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1001,7 +1022,7 @@ class ToolProviderMCPApi(Resource):
) )
return {"result": "success"} return {"result": "success"}
@console_ns.expect(parser_mcp_delete) @api.expect(parser_mcp_delete)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1024,7 +1045,7 @@ parser_auth = (
@console_ns.route("/workspaces/current/tool-provider/mcp/auth") @console_ns.route("/workspaces/current/tool-provider/mcp/auth")
class ToolMCPAuthApi(Resource): class ToolMCPAuthApi(Resource):
@console_ns.expect(parser_auth) @api.expect(parser_auth)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -1065,13 +1086,7 @@ class ToolMCPAuthApi(Resource):
return {"result": "success"} return {"result": "success"}
except MCPAuthError as e: except MCPAuthError as e:
try: try:
# Pass the extracted OAuth metadata hints to auth() auth_result = auth(provider_entity, args.get("authorization_code"))
auth_result = auth(
provider_entity,
args.get("authorization_code"),
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result) response = service.execute_auth_actions(auth_result)
@ -1081,7 +1096,7 @@ class ToolMCPAuthApi(Resource):
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e: except MCPError as e:
with Session(db.engine) as session, session.begin(): with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session) service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id) service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
@ -1142,7 +1157,7 @@ parser_cb = (
@console_ns.route("/mcp/oauth/callback") @console_ns.route("/mcp/oauth/callback")
class ToolMCPCallbackApi(Resource): class ToolMCPCallbackApi(Resource):
@console_ns.expect(parser_cb) @api.expect(parser_cb)
def get(self): def get(self):
args = parser_cb.parse_args() args = parser_cb.parse_args()
state_key = args["state"] state_key = args["state"]

View File

@ -6,6 +6,8 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.plugin_daemon import CredentialType
@ -21,13 +23,9 @@ from services.trigger.trigger_provider_service import TriggerProviderService
from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService from services.trigger.trigger_subscription_builder_service import TriggerSubscriptionBuilderService
from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService from services.trigger.trigger_subscription_operator_service import TriggerSubscriptionOperatorService
from .. import console_ns
from ..wraps import account_initialization_required, is_admin_or_owner_required, setup_required
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/icon")
class TriggerProviderIconApi(Resource): class TriggerProviderIconApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -40,7 +38,6 @@ class TriggerProviderIconApi(Resource):
return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider) return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider)
@console_ns.route("/workspaces/current/triggers")
class TriggerProviderListApi(Resource): class TriggerProviderListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -53,7 +50,6 @@ class TriggerProviderListApi(Resource):
return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id)) return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id))
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/info")
class TriggerProviderInfoApi(Resource): class TriggerProviderInfoApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -68,16 +64,17 @@ class TriggerProviderInfoApi(Resource):
) )
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
class TriggerSubscriptionListApi(Resource): class TriggerSubscriptionListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""List all trigger subscriptions for the current tenant's provider""" """List all trigger subscriptions for the current tenant's provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
return jsonable_encoder( return jsonable_encoder(
@ -92,25 +89,20 @@ class TriggerSubscriptionListApi(Resource):
raise raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource): class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
"""Add a new subscription instance for a trigger provider""" """Add a new subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_type", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
@ -127,9 +119,6 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
raise raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderGetApi(Resource): class TriggerSubscriptionBuilderGetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -141,28 +130,22 @@ class TriggerSubscriptionBuilderGetApi(Resource):
) )
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyApi(Resource): class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(parser_api)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Verify a subscription instance for a trigger provider""" """Verify a subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_api.parse_args() parser = reqparse.RequestParser()
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
# Use atomic update_and_verify to prevent race conditions # Use atomic update_and_verify to prevent race conditions
@ -180,24 +163,7 @@ class TriggerSubscriptionBuilderVerifyApi(Resource):
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource): class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@ -207,7 +173,16 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account) assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
args = parser_update_api.parse_args() parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
return jsonable_encoder( return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder( TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@ -227,9 +202,6 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
raise raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderLogsApi(Resource): class TriggerSubscriptionBuilderLogsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -248,20 +220,28 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
raise raise
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource): class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider, subscription_builder_id): def post(self, provider, subscription_builder_id):
"""Build a subscription instance for a trigger provider""" """Build a subscription instance for a trigger provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
args = parser_update_api.parse_args() if not user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
# The name of the subscription builder
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
parser.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
parser.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
# Use atomic update_and_build to prevent race conditions # Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder( TriggerSubscriptionBuilderService.update_and_build_builder(
@ -281,18 +261,17 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
@console_ns.route(
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
class TriggerSubscriptionDeleteApi(Resource): class TriggerSubscriptionDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, subscription_id: str): def post(self, subscription_id: str):
"""Delete a subscription instance""" """Delete a subscription instance"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
with Session(db.engine) as session: with Session(db.engine) as session:
@ -317,7 +296,6 @@ class TriggerSubscriptionDeleteApi(Resource):
raise raise
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize")
class TriggerOAuthAuthorizeApi(Resource): class TriggerOAuthAuthorizeApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -401,7 +379,6 @@ class TriggerOAuthAuthorizeApi(Resource):
raise raise
@console_ns.route("/oauth/plugin/<path:provider>/trigger/callback")
class TriggerOAuthCallbackApi(Resource): class TriggerOAuthCallbackApi(Resource):
@setup_required @setup_required
def get(self, provider): def get(self, provider):
@ -466,23 +443,17 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource): class TriggerOAuthClientManageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
"""Get OAuth client configuration for a provider""" """Get OAuth client configuration for a provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
@ -520,17 +491,21 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e) logger.exception("Error getting OAuth client", exc_info=e)
raise raise
@console_ns.expect(parser_oauth_client)
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
"""Configure custom OAuth client for a provider""" """Configure custom OAuth client for a provider"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
args = parser_oauth_client.parse_args() parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
@ -549,12 +524,14 @@ class TriggerOAuthClientManageApi(Resource):
@setup_required @setup_required
@login_required @login_required
@is_admin_or_owner_required
@account_initialization_required @account_initialization_required
def delete(self, provider): def delete(self, provider):
"""Remove custom OAuth client configuration""" """Remove custom OAuth client configuration"""
user = current_user user = current_user
assert isinstance(user, Account)
assert user.current_tenant_id is not None assert user.current_tenant_id is not None
if not user.is_admin_or_owner:
raise Forbidden()
try: try:
provider_id = TriggerProviderID(provider) provider_id = TriggerProviderID(provider)
@ -568,3 +545,48 @@ class TriggerOAuthClientManageApi(Resource):
except Exception as e: except Exception as e:
logger.exception("Error removing OAuth client", exc_info=e) logger.exception("Error removing OAuth client", exc_info=e)
raise raise
# Trigger Subscription
api.add_resource(TriggerProviderIconApi, "/workspaces/current/trigger-provider/<path:provider>/icon")
api.add_resource(TriggerProviderListApi, "/workspaces/current/triggers")
api.add_resource(TriggerProviderInfoApi, "/workspaces/current/trigger-provider/<path:provider>/info")
api.add_resource(TriggerSubscriptionListApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/list")
api.add_resource(
TriggerSubscriptionDeleteApi,
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/delete",
)
# Trigger Subscription Builder
api.add_resource(
TriggerSubscriptionBuilderCreateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
api.add_resource(
TriggerSubscriptionBuilderGetApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderUpdateApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderVerifyApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderBuildApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
api.add_resource(
TriggerSubscriptionBuilderLogsApi,
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/logs/<path:subscription_builder_id>",
)
# OAuth
api.add_resource(
TriggerOAuthAuthorizeApi, "/workspaces/current/trigger-provider/<path:provider>/subscriptions/oauth/authorize"
)
api.add_resource(TriggerOAuthCallbackApi, "/oauth/plugin/<path:provider>/trigger/callback")
api.add_resource(TriggerOAuthClientManageApi, "/workspaces/current/trigger-provider/<path:provider>/oauth/client")

View File

@ -1,8 +1,7 @@
import logging import logging
from flask import request from flask import request
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
@ -14,7 +13,7 @@ from controllers.common.errors import (
TooManyFilesError, TooManyFilesError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.console import console_ns from controllers.console import api, console_ns
from controllers.console.admin import admin_required from controllers.console.admin import admin_required
from controllers.console.error import AccountNotLinkTenantError from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import ( from controllers.console.wraps import (
@ -33,36 +32,8 @@ from services.file_service import FileService
from services.workspace_service import WorkspaceService from services.workspace_service import WorkspaceService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkspaceListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
class SwitchWorkspacePayload(BaseModel):
tenant_id: str
class WorkspaceCustomConfigPayload(BaseModel):
remove_webapp_brand: bool | None = None
replace_webapp_logo: str | None = None
class WorkspaceInfoPayload(BaseModel):
name: str
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(WorkspaceListQuery)
reg(SwitchWorkspacePayload)
reg(WorkspaceCustomConfigPayload)
reg(WorkspaceInfoPayload)
provider_fields = { provider_fields = {
"provider_name": fields.String, "provider_name": fields.String,
"provider_type": fields.String, "provider_type": fields.String,
@ -124,15 +95,18 @@ class TenantListApi(Resource):
@console_ns.route("/all-workspaces") @console_ns.route("/all-workspaces")
class WorkspaceListApi(Resource): class WorkspaceListApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceListQuery.__name__])
@setup_required @setup_required
@admin_required @admin_required
def get(self): def get(self):
payload = request.args.to_dict(flat=True) # type: ignore parser = (
args = WorkspaceListQuery.model_validate(payload) reqparse.RequestParser()
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
stmt = select(Tenant).order_by(Tenant.created_at.desc()) stmt = select(Tenant).order_by(Tenant.created_at.desc())
tenants = db.paginate(select=stmt, page=args.page, per_page=args.limit, error_out=False) tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
has_more = False has_more = False
if tenants.has_next: if tenants.has_next:
@ -141,8 +115,8 @@ class WorkspaceListApi(Resource):
return { return {
"data": marshal(tenants.items, workspace_fields), "data": marshal(tenants.items, workspace_fields),
"has_more": has_more, "has_more": has_more,
"limit": args.limit, "limit": args["limit"],
"page": args.page, "page": args["page"],
"total": tenants.total, "total": tenants.total,
}, 200 }, 200
@ -154,7 +128,7 @@ class TenantApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(tenant_fields) @marshal_with(tenant_fields)
def post(self): def get(self):
if request.path == "/info": if request.path == "/info":
logger.warning("Deprecated URL /info was used.") logger.warning("Deprecated URL /info was used.")
@ -176,24 +150,26 @@ class TenantApi(Resource):
return WorkspaceService.get_tenant_info(tenant), 200 return WorkspaceService.get_tenant_info(tenant), 200
parser_switch = reqparse.RequestParser().add_argument("tenant_id", type=str, required=True, location="json")
@console_ns.route("/workspaces/switch") @console_ns.route("/workspaces/switch")
class SwitchWorkspaceApi(Resource): class SwitchWorkspaceApi(Resource):
@console_ns.expect(console_ns.models[SwitchWorkspacePayload.__name__]) @api.expect(parser_switch)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self):
current_user, _ = current_account_with_tenant() current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_switch.parse_args()
args = SwitchWorkspacePayload.model_validate(payload)
# check if tenant_id is valid, 403 if not # check if tenant_id is valid, 403 if not
try: try:
TenantService.switch_tenant(current_user, args.tenant_id) TenantService.switch_tenant(current_user, args["tenant_id"])
except Exception: except Exception:
raise AccountNotLinkTenantError("Account not link tenant") raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
if new_tenant is None: if new_tenant is None:
raise ValueError("Tenant not found") raise ValueError("Tenant not found")
@ -202,21 +178,24 @@ class SwitchWorkspaceApi(Resource):
@console_ns.route("/workspaces/custom-config") @console_ns.route("/workspaces/custom-config")
class CustomConfigWorkspaceApi(Resource): class CustomConfigWorkspaceApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceCustomConfigPayload.__name__])
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom") @cloud_edition_billing_resource_check("workspace_custom")
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} parser = (
args = WorkspaceCustomConfigPayload.model_validate(payload) reqparse.RequestParser()
.add_argument("remove_webapp_brand", type=bool, location="json")
.add_argument("replace_webapp_logo", type=str, location="json")
)
args = parser.parse_args()
tenant = db.get_or_404(Tenant, current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args.remove_webapp_brand, "remove_webapp_brand": args["remove_webapp_brand"],
"replace_webapp_logo": args.replace_webapp_logo "replace_webapp_logo": args["replace_webapp_logo"]
if args.replace_webapp_logo is not None if args["replace_webapp_logo"] is not None
else tenant.custom_config_dict.get("replace_webapp_logo"), else tenant.custom_config_dict.get("replace_webapp_logo"),
} }
@ -266,22 +245,24 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201 return {"id": upload_file.id}, 201
parser_info = reqparse.RequestParser().add_argument("name", type=str, required=True, location="json")
@console_ns.route("/workspaces/info") @console_ns.route("/workspaces/info")
class WorkspaceInfoApi(Resource): class WorkspaceInfoApi(Resource):
@console_ns.expect(console_ns.models[WorkspaceInfoPayload.__name__]) @api.expect(parser_info)
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
# Change workspace name # Change workspace name
def post(self): def post(self):
_, current_tenant_id = current_account_with_tenant() _, current_tenant_id = current_account_with_tenant()
payload = console_ns.payload or {} args = parser_info.parse_args()
args = WorkspaceInfoPayload.model_validate(payload)
if not current_tenant_id: if not current_tenant_id:
raise ValueError("No current tenant") raise ValueError("No current tenant")
tenant = db.get_or_404(Tenant, current_tenant_id) tenant = db.get_or_404(Tenant, current_tenant_id)
tenant.name = args.name tenant.name = args["name"]
db.session.commit() db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}

View File

@ -315,19 +315,3 @@ def edit_permission_required(f: Callable[P, R]):
return f(*args, **kwargs) return f(*args, **kwargs)
return decorated_function return decorated_function
def is_admin_or_owner_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
raise Forbidden()
return f(*args, **kwargs)
return decorated_function

View File

@ -3,12 +3,14 @@ from typing import Literal
from flask import request from flask import request
from flask_restx import Api, Namespace, Resource, fields, reqparse from flask_restx import Api, Namespace, Resource, fields, reqparse
from flask_restx.api import HTTPStatus from flask_restx.api import HTTPStatus
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model from fields.annotation_fields import annotation_fields, build_annotation_model
from libs.login import current_user
from models import Account
from models.model import App from models.model import App
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
@ -159,10 +161,14 @@ class AnnotationUpdateDeleteApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns)) @service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str): def put(self, app_model: App, annotation_id):
"""Update an existing annotation.""" """Update an existing annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
args = annotation_create_parser.parse_args() args = annotation_create_parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation return annotation
@ -179,8 +185,13 @@ class AnnotationUpdateDeleteApi(Resource):
} }
) )
@validate_app_token @validate_app_token
@edit_permission_required def delete(self, app_model: App, annotation_id):
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation.""" """Delete an annotation."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 204 return {"result": "success"}, 204

View File

@ -17,6 +17,7 @@ from controllers.service_api.app.error import (
) )
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -29,7 +30,6 @@ from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -88,7 +88,7 @@ class CompletionApi(Resource):
This endpoint generates a completion based on the provided inputs and query. This endpoint generates a completion based on the provided inputs and query.
Supports both blocking and streaming response modes. Supports both blocking and streaming response modes.
""" """
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise AppUnavailableError() raise AppUnavailableError()
args = completion_parser.parse_args() args = completion_parser.parse_args()
@ -147,15 +147,10 @@ class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id: str): def post(self, app_model: App, end_user: EndUser, task_id: str):
"""Stop a running completion task.""" """Stop a running completion task."""
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise AppUnavailableError() raise AppUnavailableError()
AppTaskService.stop_task( AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
task_id=task_id,
invoke_from=InvokeFrom.SERVICE_API,
user_id=end_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -249,11 +244,6 @@ class ChatStopApi(Resource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppTaskService.stop_task( AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
task_id=task_id,
invoke_from=InvokeFrom.SERVICE_API,
user_id=end_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -5,7 +5,6 @@ from flask_restx import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
@ -620,9 +619,11 @@ class DatasetTagsApi(DatasetApiResource):
} }
) )
@validate_dataset_token @validate_dataset_token
@edit_permission_required
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
"""Delete a knowledge type tag.""" """Delete a knowledge type tag."""
assert isinstance(current_user, Account)
if not current_user.has_edit_permission:
raise Forbidden()
args = tag_delete_parser.parse_args() args = tag_delete_parser.parse_args()
TagService.delete_tag(args["tag_id"]) TagService.delete_tag(args["tag_id"])

View File

@ -1,10 +1,7 @@
import json import json
from typing import Self
from uuid import UUID
from flask import request from flask import request
from flask_restx import marshal, reqparse from flask_restx import marshal, reqparse
from pydantic import BaseModel, model_validator
from sqlalchemy import desc, select from sqlalchemy import desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -34,7 +31,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
# Define parsers for document operations # Define parsers for document operations
@ -54,26 +51,15 @@ document_text_create_parser = (
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") .add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
) )
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" document_text_update_parser = (
reqparse.RequestParser()
.add_argument("name", type=str, required=False, nullable=True, location="json")
class DocumentTextUpdate(BaseModel): .add_argument("text", type=str, required=False, nullable=True, location="json")
name: str | None = None .add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
text: str | None = None .add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
process_rule: ProcessRule | None = None .add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
doc_form: str = "text_model" .add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
doc_language: str = "English" )
retrieval_model: RetrievalModel | None = None
@model_validator(mode="after")
def check_text_and_name(self) -> Self:
if self.text is not None and self.name is None:
raise ValueError("name is required when text is provided")
return self
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@service_api_ns.route( @service_api_ns.route(
@ -174,7 +160,7 @@ class DocumentAddByTextApi(DatasetApiResource):
class DocumentUpdateByTextApi(DatasetApiResource): class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True) @service_api_ns.expect(document_text_update_parser)
@service_api_ns.doc("update_document_by_text") @service_api_ns.doc("update_document_by_text")
@service_api_ns.doc(description="Update an existing document by providing text content") @service_api_ns.doc(description="Update an existing document by providing text content")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@ -187,10 +173,12 @@ class DocumentUpdateByTextApi(DatasetApiResource):
) )
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): def post(self, tenant_id, dataset_id, document_id):
"""Update document by text.""" """Update document by text."""
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True) args = document_text_update_parser.parse_args()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -210,9 +198,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
if args.get("text"): if args["text"]:
text = args.get("text") text = args.get("text")
name = args.get("name") name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
if not current_user: if not current_user:
raise ValueError("current_user is required") raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text( upload_file = FileService(db.engine).upload_text(
@ -466,16 +456,12 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
status = request.args.get("status", default=None, type=str)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.where(Document.name.like(search))

View File

@ -1,7 +1,7 @@
import logging import logging
import time import time
from flask import jsonify, request from flask import jsonify
from werkzeug.exceptions import NotFound, RequestEntityTooLarge from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp from controllers.trigger import bp
@ -28,14 +28,8 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config) webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
return webhook_trigger, workflow, node_config, webhook_data, None return webhook_trigger, workflow, node_config, webhook_data, None
except ValueError as e: except ValueError as e:
# Provide minimal context for error reporting without risking another parse failure # Fall back to raw extraction for error reporting
webhook_data = { webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
"method": request.method,
"headers": dict(request.headers),
"query_params": dict(request.args),
"body": {},
"files": {},
}
return webhook_trigger, workflow, node_config, webhook_data, str(e) return webhook_trigger, workflow, node_config, webhook_data, str(e)

View File

@ -17,6 +17,7 @@ from controllers.web.error import (
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
ModelCurrentlyNotSupportError, ModelCurrentlyNotSupportError,
@ -28,7 +29,6 @@ from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,7 +64,7 @@ class CompletionApi(WebApiResource):
} }
) )
def post(self, app_model, end_user): def post(self, app_model, end_user):
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
parser = ( parser = (
@ -125,15 +125,10 @@ class CompletionStopApi(WebApiResource):
} }
) )
def post(self, app_model, end_user, task_id): def post(self, app_model, end_user, task_id):
if app_model.mode != AppMode.COMPLETION: if app_model.mode != "completion":
raise NotCompletionAppError() raise NotCompletionAppError()
AppTaskService.stop_task( AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
task_id=task_id,
invoke_from=InvokeFrom.WEB_APP,
user_id=end_user.id,
app_mode=AppMode.value_of(app_model.mode),
)
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -239,11 +234,6 @@ class ChatStopApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError() raise NotChatAppError()
AppTaskService.stop_task( AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id)
task_id=task_id,
invoke_from=InvokeFrom.WEB_APP,
user_id=end_user.id,
app_mode=app_mode,
)
return {"result": "success"}, 200 return {"result": "success"}, 200

View File

@ -81,7 +81,6 @@ class LoginStatusApi(Resource):
) )
def get(self): def get(self):
app_code = request.args.get("app_code") app_code = request.args.get("app_code")
user_id = request.args.get("user_id")
token = extract_webapp_access_token(request) token = extract_webapp_access_token(request)
if not app_code: if not app_code:
return { return {
@ -104,7 +103,7 @@ class LoginStatusApi(Resource):
user_logged_in = False user_logged_in = False
try: try:
_ = decode_jwt_token(app_code=app_code, user_id=user_id) _ = decode_jwt_token(app_code=app_code)
app_logged_in = True app_logged_in = True
except Exception: except Exception:
app_logged_in = False app_logged_in = False

View File

@ -38,7 +38,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator return decorator
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None): def decode_jwt_token(app_code: str | None = None):
system_features = FeatureService.get_system_features() system_features = FeatureService.get_system_features()
if not app_code: if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE)) app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
@ -63,10 +63,6 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
if not end_user: if not end_user:
raise NotFound() raise NotFound()
# Validate user_id against end_user's session_id if provided
if user_id is not None and end_user.session_id != user_id:
raise Unauthorized("Authentication has expired.")
# for enterprise webapp auth # for enterprise webapp auth
app_web_auth_enabled = False app_web_auth_enabled = False
webapp_settings = None webapp_settings = None

View File

@ -112,7 +112,6 @@ class VariableEntity(BaseModel):
type: VariableEntityType type: VariableEntityType
required: bool = False required: bool = False
hide: bool = False hide: bool = False
default: Any = None
max_length: int | None = None max_length: int | None = None
options: Sequence[str] = Field(default_factory=list) options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)

View File

@ -62,8 +62,7 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.enums import WorkflowExecutionStatus from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@ -73,7 +72,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.workflow import Workflow, WorkflowNodeExecutionModel from models.workflow import Workflow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -581,7 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session: with self._database_session() as session:
# Save message # Save message
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp yield workflow_finish_resp
elif event.stopped_by in ( elif event.stopped_by in (
@ -591,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# When hitting input-moderation or annotation-reply, the workflow will not start # When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session: with self._database_session() as session:
# Save message # Save message
self._save_message(session=session, trace_manager=trace_manager) self._save_message(session=session)
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
@ -600,7 +599,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent, event: QueueAdvancedChatMessageEndEvent,
*, *,
graph_runtime_state: GraphRuntimeState | None = None, graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs, **kwargs,
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events.""" """Handle advanced chat message end events."""
@ -618,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message # Save message
with self._database_session() as session: with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager) self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
@ -772,13 +770,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message( def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session) message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer # If there are assistant files, remove markdown image links from answer
@ -817,14 +809,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
metadata = self._task_state.metadata.model_dump() metadata = self._task_state.metadata.model_dump()
message.message_metadata = json.dumps(jsonable_encoder(metadata)) message.message_metadata = json.dumps(jsonable_encoder(metadata))
# Extract model provider and model_id from workflow node executions for tracing
if message.workflow_run_id:
model_info = self._extract_model_info_from_workflow(session, message.workflow_run_id)
if model_info:
message.model_provider = model_info.get("provider")
message.model_id = model_info.get("model")
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=message.id, message_id=message.id,
@ -842,68 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
] ]
session.add_all(message_files) session.add_all(message_files)
# Trigger MESSAGE_TRACE for tracing integrations
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
)
def _extract_model_info_from_workflow(self, session: Session, workflow_run_id: str) -> dict[str, str] | None:
"""
Extract model provider and model_id from workflow node executions.
Returns dict with 'provider' and 'model' keys, or None if not found.
"""
try:
# Query workflow node executions for LLM or Agent nodes
stmt = (
select(WorkflowNodeExecutionModel)
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.where(WorkflowNodeExecutionModel.node_type.in_(["llm", "agent"]))
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.limit(1)
)
node_execution = session.scalar(stmt)
if not node_execution:
return None
# Try to extract from execution_metadata for agent nodes
if node_execution.execution_metadata:
try:
metadata = json.loads(node_execution.execution_metadata)
agent_log = metadata.get("agent_log", [])
# Look for the first agent thought with provider info
for log_entry in agent_log:
entry_metadata = log_entry.get("metadata", {})
provider_str = entry_metadata.get("provider")
if provider_str:
# Parse format like "langgenius/deepseek/deepseek"
parts = provider_str.split("/")
if len(parts) >= 3:
return {"provider": parts[1], "model": parts[2]}
elif len(parts) == 2:
return {"provider": parts[0], "model": parts[1]}
except (json.JSONDecodeError, KeyError, AttributeError) as e:
logger.debug("Failed to parse execution_metadata: %s", e)
# Try to extract from process_data for llm nodes
if node_execution.process_data:
try:
process_data = json.loads(node_execution.process_data)
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider and model:
return {"provider": provider, "model": model}
except (json.JSONDecodeError, KeyError) as e:
logger.debug("Failed to parse process_data: %s", e)
return None
except Exception as e:
logger.warning("Failed to extract model info from workflow: %s", e)
return None
def _seed_graph_runtime_state_from_queue_manager(self) -> None: def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present.""" """Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@ -93,11 +93,7 @@ class BaseAppGenerator:
if value is None: if value is None:
if variable_entity.required: if variable_entity.required:
raise ValueError(f"{variable_entity.variable} is required in input form") raise ValueError(f"{variable_entity.variable} is required in input form")
# Use default value and continue validation to ensure type conversion return value
value = variable_entity.default
# If default is also None, return None directly
if value is None:
return None
if variable_entity.type in { if variable_entity.type in {
VariableEntityType.TEXT_INPUT, VariableEntityType.TEXT_INPUT,
@ -155,17 +151,8 @@ class BaseAppGenerator:
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
) )
case VariableEntityType.CHECKBOX: case VariableEntityType.CHECKBOX:
if isinstance(value, str): if not isinstance(value, bool):
normalized_value = value.strip().lower() raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value")
if normalized_value in {"true", "1", "yes", "on"}:
value = True
elif normalized_value in {"false", "0", "no", "off"}:
value = False
elif isinstance(value, (int, float)):
if value == 1:
value = True
elif value == 0:
value = False
case _: case _:
raise AssertionError("this statement should be unreachable.") raise AssertionError("this statement should be unreachable.")

View File

@ -163,7 +163,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type=datasource_type, datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info), datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id, datasource_node_id=start_node_id,
input_data=dict(inputs), input_data=inputs,
pipeline_id=pipeline.id, pipeline_id=pipeline.id,
created_by=user.id, created_by=user.id,
) )

View File

@ -145,8 +145,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
**extract_external_trace_id_from_args(args), **extract_external_trace_id_from_args(args),
} }
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args # for trigger debug run, not prepare user inputs
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args): if self._should_prepare_user_inputs(args):
inputs = self._prepare_user_inputs( inputs = self._prepare_user_inputs(
user_inputs=inputs, user_inputs=inputs,

View File

@ -258,10 +258,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state) run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response( start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id, workflow_run_id=run_id,
@ -418,6 +414,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state, graph_runtime_state=validated_state,
) )
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp yield workflow_finish_resp
def _handle_workflow_partial_success_event( def _handle_workflow_partial_success_event(
@ -438,6 +437,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
graph_runtime_state=validated_state, graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count, exceptions_count=event.exceptions_count,
) )
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp yield workflow_finish_resp
def _handle_workflow_failed_and_stop_events( def _handle_workflow_failed_and_stop_events(
@ -468,6 +471,10 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
error=error, error=error,
exceptions_count=exceptions_count, exceptions_count=exceptions_count,
) )
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp yield workflow_finish_resp
def _handle_text_chunk_event( def _handle_text_chunk_event(
@ -637,17 +644,17 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if not workflow_run_id: if not workflow_run_id:
return return
workflow_app_log = WorkflowAppLog( workflow_app_log = WorkflowAppLog()
tenant_id=self._application_generate_entity.app_config.tenant_id, workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
app_id=self._application_generate_entity.app_config.app_id, workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_id=self._workflow.id, workflow_app_log.workflow_id = self._workflow.id
workflow_run_id=workflow_run_id, workflow_app_log.workflow_run_id = workflow_run_id
created_from=created_from.value, workflow_app_log.created_from = created_from.value
created_by_role=self._created_by_role, workflow_app_log.created_by_role = self._created_by_role
created_by=self._user_id, workflow_app_log.created_by = self._user_id
)
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: list[str] | None = None self, text: str, from_variable_selector: list[str] | None = None

View File

@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Optional
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file import File, FileUploadConfig from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
class InvokeFrom(StrEnum): class InvokeFrom(StrEnum):
""" """
@ -275,8 +275,10 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
start_node_id: str | None = None start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
# Rebuild models that use forward references
AppGenerateEntity.model_rebuild() AppGenerateEntity.model_rebuild()
EasyUIBasedAppGenerateEntity.model_rebuild() EasyUIBasedAppGenerateEntity.model_rebuild()
ConversationAppGenerateEntity.model_rebuild() ConversationAppGenerateEntity.model_rebuild()

View File

@ -40,9 +40,6 @@ class EasyUITaskState(TaskState):
""" """
llm_result: LLMResult llm_result: LLMResult
first_token_time: float | None = None
last_token_time: float | None = None
is_streaming_response: bool = False
class WorkflowTaskState(TaskState): class WorkflowTaskState(TaskState):

View File

@ -118,7 +118,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
state_owner_user_id=self._state_owner_user_id, state_owner_user_id=self._state_owner_user_id,
state=state.dumps(), state=state.dumps(),
pause_reasons=event.reasons,
) )
def on_graph_end(self, error: Exception | None) -> None: def on_graph_end(self, error: Exception | None) -> None:

Some files were not shown because too many files have changed in this diff Show More