mirror of
https://github.com/langgenius/dify.git
synced 2026-01-20 20:19:28 +08:00
Compare commits
2 Commits
deploy/tri
...
chore/work
| Author | SHA1 | Date | |
|---|---|---|---|
| bf20f3aa8b | |||
| c90e564d99 |
@ -1,6 +0,0 @@
|
||||
# Cursor Rules for Dify Project
|
||||
|
||||
## Automated Test Generation
|
||||
|
||||
- Use `web/testing/testing.md` as the canonical instruction set for generating frontend automated tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
@ -11,7 +11,7 @@
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
},
|
||||
"ghcr.io/devcontainers-extra/features/npm-package:1": {
|
||||
"ghcr.io/devcontainers-contrib/features/npm-package:1": {
|
||||
"package": "typescript",
|
||||
"version": "latest"
|
||||
},
|
||||
|
||||
@ -6,10 +6,11 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
echo "alias stop-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down\"" >> ~/.bashrc
|
||||
|
||||
source /home/vscode/.bashrc
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ trim_trailing_whitespace = false
|
||||
|
||||
# Matches multiple files with brace expansion notation
|
||||
# Set default charset
|
||||
[*.{js,jsx,ts,tsx,mjs}]
|
||||
[*.{js,tsx}]
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
|
||||
226
.github/CODEOWNERS
vendored
226
.github/CODEOWNERS
vendored
@ -1,226 +0,0 @@
|
||||
# CODEOWNERS
|
||||
# This file defines code ownership for the Dify project.
|
||||
# Each line is a file pattern followed by one or more owners.
|
||||
# Owners can be @username, @org/team-name, or email addresses.
|
||||
# For more information, see: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# Backend (default owner, more specific rules below will override)
|
||||
api/ @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
api/core/workflow/graph_engine/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/runtime/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/graph_events/ @laipz8200 @QuantumGhost
|
||||
api/core/workflow/node_events/ @laipz8200 @QuantumGhost
|
||||
api/core/model_runtime/ @laipz8200 @QuantumGhost
|
||||
|
||||
# Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM)
|
||||
api/core/workflow/nodes/agent/ @Nov1c444
|
||||
api/core/workflow/nodes/iteration/ @Nov1c444
|
||||
api/core/workflow/nodes/loop/ @Nov1c444
|
||||
api/core/workflow/nodes/llm/ @Nov1c444
|
||||
|
||||
# Backend - RAG (Retrieval Augmented Generation)
|
||||
api/core/rag/ @JohnJyong
|
||||
api/services/rag_pipeline/ @JohnJyong
|
||||
api/services/dataset_service.py @JohnJyong
|
||||
api/services/knowledge_service.py @JohnJyong
|
||||
api/services/external_knowledge_service.py @JohnJyong
|
||||
api/services/hit_testing_service.py @JohnJyong
|
||||
api/services/metadata_service.py @JohnJyong
|
||||
api/services/vector_service.py @JohnJyong
|
||||
api/services/entities/knowledge_entities/ @JohnJyong
|
||||
api/services/entities/external_knowledge_entities/ @JohnJyong
|
||||
api/controllers/console/datasets/ @JohnJyong
|
||||
api/controllers/service_api/dataset/ @JohnJyong
|
||||
api/models/dataset.py @JohnJyong
|
||||
api/tasks/rag_pipeline/ @JohnJyong
|
||||
api/tasks/add_document_to_index_task.py @JohnJyong
|
||||
api/tasks/batch_clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_document_task.py @JohnJyong
|
||||
api/tasks/clean_notion_document_task.py @JohnJyong
|
||||
api/tasks/document_indexing_task.py @JohnJyong
|
||||
api/tasks/document_indexing_sync_task.py @JohnJyong
|
||||
api/tasks/document_indexing_update_task.py @JohnJyong
|
||||
api/tasks/duplicate_document_indexing_task.py @JohnJyong
|
||||
api/tasks/recover_document_indexing_task.py @JohnJyong
|
||||
api/tasks/remove_document_from_index_task.py @JohnJyong
|
||||
api/tasks/retry_document_indexing_task.py @JohnJyong
|
||||
api/tasks/sync_website_document_indexing_task.py @JohnJyong
|
||||
api/tasks/batch_create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/create_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/delete_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segment_from_index_task.py @JohnJyong
|
||||
api/tasks/disable_segments_from_index_task.py @JohnJyong
|
||||
api/tasks/enable_segment_to_index_task.py @JohnJyong
|
||||
api/tasks/enable_segments_to_index_task.py @JohnJyong
|
||||
api/tasks/clean_dataset_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_index_update_task.py @JohnJyong
|
||||
api/tasks/deal_dataset_vector_index_task.py @JohnJyong
|
||||
|
||||
# Backend - Plugins
|
||||
api/core/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/services/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/console/workspace/plugin.py @Mairuis @Yeuoly @Stream29
|
||||
api/controllers/inner_api/plugin/ @Mairuis @Yeuoly @Stream29
|
||||
api/tasks/process_tenant_plugin_autoupgrade_check_task.py @Mairuis @Yeuoly @Stream29
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
api/controllers/trigger/ @Mairuis @Yeuoly
|
||||
api/controllers/console/app/workflow_trigger.py @Mairuis @Yeuoly
|
||||
api/controllers/console/workspace/trigger_providers.py @Mairuis @Yeuoly
|
||||
api/core/trigger/ @Mairuis @Yeuoly
|
||||
api/core/app/layers/trigger_post_layer.py @Mairuis @Yeuoly
|
||||
api/services/trigger/ @Mairuis @Yeuoly
|
||||
api/models/trigger.py @Mairuis @Yeuoly
|
||||
api/fields/workflow_trigger_fields.py @Mairuis @Yeuoly
|
||||
api/repositories/workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis @Yeuoly
|
||||
api/libs/schedule_utils.py @Mairuis @Yeuoly
|
||||
api/services/workflow/scheduler.py @Mairuis @Yeuoly
|
||||
api/schedule/trigger_provider_refresh_task.py @Mairuis @Yeuoly
|
||||
api/schedule/workflow_schedule_task.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_processing_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/trigger_subscription_refresh_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_schedule_tasks.py @Mairuis @Yeuoly
|
||||
api/tasks/workflow_cfs_scheduler/ @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis @Yeuoly
|
||||
api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Async Workflow
|
||||
api/services/async_workflow_service.py @Mairuis @Yeuoly
|
||||
api/tasks/async_workflow_tasks.py @Mairuis @Yeuoly
|
||||
|
||||
# Backend - Billing
|
||||
api/services/billing_service.py @hj24 @zyssyz123
|
||||
api/controllers/console/billing/ @hj24 @zyssyz123
|
||||
|
||||
# Backend - Enterprise
|
||||
api/configs/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/enterprise/ @GarfieldDai @GareArc
|
||||
api/services/feature_service.py @GarfieldDai @GareArc
|
||||
api/controllers/console/feature.py @GarfieldDai @GareArc
|
||||
api/controllers/web/feature.py @GarfieldDai @GareArc
|
||||
|
||||
# Backend - Database Migrations
|
||||
api/migrations/ @snakevash @laipz8200
|
||||
|
||||
# Frontend
|
||||
web/ @iamjoel
|
||||
|
||||
# Frontend - App - Orchestration
|
||||
web/app/components/workflow/ @iamjoel @zxhlyh
|
||||
web/app/components/workflow-app/ @iamjoel @zxhlyh
|
||||
web/app/components/app/configuration/ @iamjoel @zxhlyh
|
||||
web/app/components/app/app-publisher/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Chat
|
||||
web/app/components/base/chat/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - WebApp - Completion
|
||||
web/app/components/share/text-generation/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - App - List and Creation
|
||||
web/app/components/apps/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-dialog/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-app-modal/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/create-from-dsl-modal/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - API Documentation
|
||||
web/app/components/develop/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Logs and Annotations
|
||||
web/app/components/app/workflow-log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/log-annotation/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/annotation/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Monitoring
|
||||
web/app/(commonLayout)/app/(appDetailLayout)/\[appId\]/overview/ @JzoNgKVO @iamjoel
|
||||
web/app/components/app/overview/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - App - Settings
|
||||
web/app/components/app-sidebar/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - Hit Testing
|
||||
web/app/components/datasets/hit-testing/ @JzoNgKVO @iamjoel
|
||||
|
||||
# Frontend - RAG - List and Creation
|
||||
web/app/components/datasets/list/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/create-from-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/datasets/external-knowledge-base/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Orchestration (general rule first, specific rules below override)
|
||||
web/app/components/rag-pipeline/ @iamjoel @WTW0313
|
||||
web/app/components/rag-pipeline/components/rag-pipeline-main.tsx @iamjoel @zxhlyh
|
||||
web/app/components/rag-pipeline/store/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - RAG - Documents List
|
||||
web/app/components/datasets/documents/list.tsx @iamjoel @WTW0313
|
||||
web/app/components/datasets/documents/create-from-pipeline/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Segments List
|
||||
web/app/components/datasets/documents/detail/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - RAG - Settings
|
||||
web/app/components/datasets/settings/ @iamjoel @WTW0313
|
||||
|
||||
# Frontend - Ecosystem - Plugins
|
||||
web/app/components/plugins/ @iamjoel @zhsama
|
||||
|
||||
# Frontend - Ecosystem - Tools
|
||||
web/app/components/tools/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Ecosystem - MarketPlace
|
||||
web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
|
||||
|
||||
# Frontend - Login and Registration
|
||||
web/app/signin/ @douxc @iamjoel
|
||||
web/app/signup/ @douxc @iamjoel
|
||||
web/app/reset-password/ @douxc @iamjoel
|
||||
web/app/install/ @douxc @iamjoel
|
||||
web/app/init/ @douxc @iamjoel
|
||||
web/app/forgot-password/ @douxc @iamjoel
|
||||
web/app/account/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Service Authentication
|
||||
web/service/base.ts @douxc @iamjoel
|
||||
|
||||
# Frontend - WebApp Authentication and Access Control
|
||||
web/app/(shareLayout)/components/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-signin/ @douxc @iamjoel
|
||||
web/app/(shareLayout)/webapp-reset-password/ @douxc @iamjoel
|
||||
web/app/components/app/app-access-control/ @douxc @iamjoel
|
||||
|
||||
# Frontend - Explore Page
|
||||
web/app/components/explore/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Personal Settings
|
||||
web/app/components/header/account-setting/ @CodingOnStar @iamjoel
|
||||
web/app/components/header/account-dropdown/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Analytics
|
||||
web/app/components/base/ga/ @CodingOnStar @iamjoel
|
||||
|
||||
# Frontend - Base Components
|
||||
web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
web/utils/time.ts @iamjoel @zxhlyh
|
||||
web/utils/format.ts @iamjoel @zxhlyh
|
||||
web/utils/clipboard.ts @iamjoel @zxhlyh
|
||||
web/hooks/use-document-title.ts @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Billing and Education
|
||||
web/app/components/billing/ @iamjoel @zxhlyh
|
||||
web/app/education-apply/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Workspace
|
||||
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
|
||||
12
.github/copilot-instructions.md
vendored
12
.github/copilot-instructions.md
vendored
@ -1,12 +0,0 @@
|
||||
# Copilot Instructions
|
||||
|
||||
GitHub Copilot must follow the unified frontend testing requirements documented in `web/testing/testing.md`.
|
||||
|
||||
Key reminders:
|
||||
|
||||
- Generate tests using the mandated tech stack, naming, and code style (AAA pattern, `fireEvent`, descriptive test names, cleans up mocks).
|
||||
- Cover rendering, prop combinations, and edge cases by default; extend coverage for hooks, routing, async flows, and domain-specific components when applicable.
|
||||
- Target >95% line and branch coverage and 100% function/statement coverage.
|
||||
- Apply the project's mocking conventions for i18n, toast notifications, and Next.js utilities.
|
||||
|
||||
Any suggestions from Copilot that conflict with `web/testing/testing.md` should be revised before acceptance.
|
||||
32
.github/workflows/api-tests.yml
vendored
32
.github/workflows/api-tests.yml
vendored
@ -39,11 +39,25 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev pyrefly
|
||||
uv run pyrefly check || true
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
@ -62,7 +76,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
db
|
||||
redis
|
||||
sandbox
|
||||
ssrf_proxy
|
||||
@ -79,19 +93,3 @@ jobs:
|
||||
|
||||
- name: Run TestContainers
|
||||
run: uv run --project api bash dev/pytest/pytest_testcontainers.sh
|
||||
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Coverage Summary
|
||||
run: |
|
||||
set -x
|
||||
# Extract coverage percentage and create a summary
|
||||
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
|
||||
|
||||
# Create a detailed coverage summary
|
||||
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
|
||||
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
9
.github/workflows/autofix.yml
vendored
9
.github/workflows/autofix.yml
vendored
@ -2,8 +2,6 @@ name: autofix.ci
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -28,17 +26,10 @@ jobs:
|
||||
# Format code
|
||||
uv run ruff format ..
|
||||
|
||||
- name: count migration progress
|
||||
run: |
|
||||
cd api
|
||||
./cnt_base.sh
|
||||
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A = db.Column($$$B)' -r '$A = mapped_column($$$B)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg -p '$A : $T = db.Column($$$B)' -r '$A : $T = mapped_column($$$B)' -l py --update-all
|
||||
# Convert Optional[T] to T | None (ignoring quoted types)
|
||||
cat > /tmp/optional-rule.yml << 'EOF'
|
||||
id: convert-optional-to-union
|
||||
|
||||
3
.github/workflows/build-push.yml
vendored
3
.github/workflows/build-push.yml
vendored
@ -4,7 +4,8 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/**"
|
||||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
|
||||
61
.github/workflows/db-migration-test.yml
vendored
61
.github/workflows/db-migration-test.yml
vendored
@ -8,7 +8,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
db-migration-test-postgres:
|
||||
db-migration-test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
@ -45,7 +45,7 @@ jobs:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
db
|
||||
redis
|
||||
|
||||
- name: Prepare configs
|
||||
@ -57,60 +57,3 @@ jobs:
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env for MySQL
|
||||
run: |
|
||||
cd docker
|
||||
cp middleware.env.example middleware.env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' middleware.env
|
||||
sed -i 's/DB_HOST=db_postgres/DB_HOST=db_mysql/' middleware.env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' middleware.env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
services: |
|
||||
db_mysql
|
||||
redis
|
||||
|
||||
- name: Prepare configs for MySQL
|
||||
run: |
|
||||
cd api
|
||||
cp .env.example .env
|
||||
sed -i 's/DB_TYPE=postgresql/DB_TYPE=mysql/' .env
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
name: Deploy Trigger Dev
|
||||
name: Deploy RAG Dev
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@ -7,7 +7,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/trigger-dev"
|
||||
- "deploy/rag-dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -16,12 +16,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.TRIGGER_SSH_HOST }}
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
3
.github/workflows/expose_service_ports.sh
vendored
3
.github/workflows/expose_service_ports.sh
vendored
@ -1,7 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
yq eval '.services.weaviate.ports += ["8080:8080"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.weaviate.ports += ["50051:50051"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.qdrant.ports += ["6333:6333"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.chroma.ports += ["8000:8000"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services["milvus-standalone"].ports += ["19530:19530"]' -i docker/docker-compose.yaml
|
||||
@ -14,4 +13,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
|
||||
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
|
||||
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
|
||||
|
||||
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
|
||||
|
||||
5
.github/workflows/style.yml
vendored
5
.github/workflows/style.yml
vendored
@ -103,11 +103,6 @@ jobs:
|
||||
run: |
|
||||
pnpm run lint
|
||||
|
||||
- name: Web type check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run type-check
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@ -20,22 +20,22 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-depth: 2
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
git fetch origin "${{ github.event.before }}" || true
|
||||
git fetch origin "${{ github.sha }}" || true
|
||||
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.ts')
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
file_args=""
|
||||
for file in $changed_files; do
|
||||
filename=$(basename "$file" .ts)
|
||||
file_args="$file_args --file $filename"
|
||||
file_args="$file_args --file=$filename"
|
||||
done
|
||||
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
|
||||
echo "File arguments: $file_args"
|
||||
@ -77,15 +77,12 @@ jobs:
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: 'chore(i18n): update translations based on en-US changes'
|
||||
title: 'chore(i18n): translate i18n files and update type definitions'
|
||||
commit-message: Update i18n files and type definitions based on en-US changes
|
||||
title: 'chore: translate i18n files and update type definitions'
|
||||
body: |
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
|
||||
**Triggered by:** ${{ github.sha }}
|
||||
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
branch: chore/automated-i18n-updates-${{ github.sha }}
|
||||
delete-branch: true
|
||||
branch: chore/automated-i18n-updates
|
||||
|
||||
18
.github/workflows/vdb-tests.yml
vendored
18
.github/workflows/vdb-tests.yml
vendored
@ -51,13 +51,13 @@ jobs:
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
- name: Set up Vector Store (TiDB)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: docker/tidb/docker-compose.yaml
|
||||
services: |
|
||||
tidb
|
||||
tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
@ -83,8 +83,8 @@ jobs:
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
- name: Check VDB Ready (TiDB)
|
||||
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@ -6,9 +6,6 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# *db files
|
||||
*.db
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
@ -100,7 +97,6 @@ __pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat-schedule.db
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
@ -186,8 +182,6 @@ docker/volumes/couchbase/*
|
||||
docker/volumes/oceanbase/*
|
||||
docker/volumes/plugin_daemon/*
|
||||
docker/volumes/matrixone/*
|
||||
docker/volumes/mysql/*
|
||||
docker/volumes/seekdb/*
|
||||
!docker/volumes/oceanbase/init.d
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
@ -240,7 +234,4 @@ scripts/stress-test/reports/
|
||||
|
||||
# mcp
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
.serena/
|
||||
9
.vscode/launch.json.template
vendored
9
.vscode/launch.json.template
vendored
@ -8,7 +8,8 @@
|
||||
"module": "flask",
|
||||
"env": {
|
||||
"FLASK_APP": "app.py",
|
||||
"FLASK_ENV": "development"
|
||||
"FLASK_ENV": "development",
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
@ -27,7 +28,9 @@
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"env": {},
|
||||
"env": {
|
||||
"GEVENT_SUPPORT": "True"
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"app.celery",
|
||||
@ -37,7 +40,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor",
|
||||
"dataset,generation,mail,ops_trace",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
# Windsurf Testing Rules
|
||||
|
||||
- Use `web/testing/testing.md` as the single source of truth for frontend automated testing.
|
||||
- Honor every requirement in that document when generating or accepting tests.
|
||||
- When proposing or saving tests, re-read that document and follow every requirement.
|
||||
@ -14,7 +14,7 @@ The codebase is split into:
|
||||
|
||||
- Run backend CLI commands through `uv run --project api <command>`.
|
||||
|
||||
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
- Backend QA gate requires passing `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before review.
|
||||
|
||||
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
|
||||
|
||||
|
||||
@ -77,8 +77,6 @@ How we prioritize:
|
||||
|
||||
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
|
||||
|
||||
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
|
||||
|
||||
#### Backend
|
||||
|
||||
For setting up the backend service, kindly refer to our detailed [instructions](https://github.com/langgenius/dify/blob/main/api/README.md) in the `api/README.md` file. This document contains step-by-step guidance to help you get the backend up and running smoothly.
|
||||
|
||||
8
Makefile
8
Makefile
@ -70,11 +70,6 @@ type-check:
|
||||
@uv run --directory api --dev basedpyright
|
||||
@echo "✅ Type check complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
|
||||
@echo "✅ Tests complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
@ -124,7 +119,6 @@ help:
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format and fix code with ruff"
|
||||
@echo " make type-check - Run type checking with basedpyright"
|
||||
@echo " make test - Run backend unit tests"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@ -134,4 +128,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check
|
||||
|
||||
26
README.md
26
README.md
@ -36,12 +36,6 @@
|
||||
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
|
||||
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
|
||||
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Health Score" src="https://insights.linuxfoundation.org/api/badge/health-score?project=langgenius-dify"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Contributors" src="https://insights.linuxfoundation.org/api/badge/contributors?project=langgenius-dify"></a>
|
||||
<a href="https://insights.linuxfoundation.org/project/langgenius-dify" target="_blank">
|
||||
<img alt="LFX Active Contributors" src="https://insights.linuxfoundation.org/api/badge/active-contributors?project=langgenius-dify"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
@ -69,7 +63,7 @@ Dify is an open-source platform for developing LLM applications. Its intuitive i
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4 GiB
|
||||
|
||||
<br/>
|
||||
</br>
|
||||
|
||||
The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
|
||||
|
||||
@ -115,15 +109,15 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
|
||||
|
||||
## Using Dify
|
||||
|
||||
- **Cloud <br/>**
|
||||
- **Cloud </br>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition<br/>**
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations<br/>**
|
||||
We provide additional enterprise-centric features. [Send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss your enterprise needs. <br/>
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=%5BGitHub%5DBusiness%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
@ -135,18 +129,8 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
## Advanced Setup
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
|
||||
|
||||
- [Grafana Dashboard by @bowenliang123](https://github.com/bowenliang123/dify-grafana-dashboard)
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
|
||||
@ -27,9 +27,6 @@ FILES_URL=http://localhost:5001
|
||||
# Example: INTERNAL_FILES_URL=http://api:5001
|
||||
INTERNAL_FILES_URL=http://127.0.0.1:5001
|
||||
|
||||
# TRIGGER URL
|
||||
TRIGGER_URL=http://localhost:5001
|
||||
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
@ -72,15 +69,12 @@ REDIS_CLUSTERS_PASSWORD=
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
|
||||
CELERY_BACKEND=redis
|
||||
|
||||
# Database configuration
|
||||
DB_TYPE=postgresql
|
||||
# PostgreSQL database configuration
|
||||
DB_USERNAME=postgres
|
||||
DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
|
||||
@ -162,11 +156,9 @@ SUPABASE_URL=your-server-url
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
|
||||
# When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). Leading dots are optional.
|
||||
COOKIE_DOMAIN=
|
||||
|
||||
# Vector database configuration
|
||||
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
|
||||
VECTOR_STORE=weaviate
|
||||
# Prefix used to create collection name in vector database
|
||||
VECTOR_INDEX_NAME_PREFIX=Vector_index
|
||||
@ -176,18 +168,6 @@ WEAVIATE_ENDPOINT=http://localhost:8080
|
||||
WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
WEAVIATE_TOKENIZATION=word
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
OCEANBASE_FULLTEXT_PARSER=ik
|
||||
SEEKDB_MEMORY_LIMIT=2G
|
||||
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
@ -354,14 +334,14 @@ LINDORM_PASSWORD=admin
|
||||
LINDORM_USING_UGC=True
|
||||
LINDORM_QUERY_TIMEOUT=1
|
||||
|
||||
# AlibabaCloud MySQL Vector configuration
|
||||
ALIBABACLOUD_MYSQL_HOST=127.0.0.1
|
||||
ALIBABACLOUD_MYSQL_PORT=3306
|
||||
ALIBABACLOUD_MYSQL_USER=root
|
||||
ALIBABACLOUD_MYSQL_PASSWORD=root
|
||||
ALIBABACLOUD_MYSQL_DATABASE=dify
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
OCEANBASE_VECTOR_USER=root@test
|
||||
OCEANBASE_VECTOR_PASSWORD=difyai123456
|
||||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
OCEANBASE_ENABLE_HYBRID_SEARCH=false
|
||||
|
||||
# openGauss configuration
|
||||
OPENGAUSS_HOST=127.0.0.1
|
||||
@ -379,12 +359,6 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Comma-separated list of file extensions blocked from upload for security reasons.
|
||||
# Extensions should be lowercase without dots (e.g., exe,bat,sh,dll).
|
||||
# Empty by default to allow all file types.
|
||||
# Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
|
||||
UPLOAD_FILE_EXTENSION_BLACKLIST=
|
||||
|
||||
# Model configuration
|
||||
MULTIMODAL_SEND_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
@ -451,9 +425,6 @@ CODE_EXECUTION_SSL_VERIFY=True
|
||||
CODE_EXECUTION_POOL_MAX_CONNECTIONS=100
|
||||
CODE_EXECUTION_POOL_MAX_KEEPALIVE_CONNECTIONS=20
|
||||
CODE_EXECUTION_POOL_KEEPALIVE_EXPIRY=5.0
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT=10
|
||||
CODE_EXECUTION_READ_TIMEOUT=60
|
||||
CODE_EXECUTION_WRITE_TIMEOUT=10
|
||||
CODE_MAX_NUMBER=9223372036854775807
|
||||
CODE_MIN_NUMBER=-9223372036854775808
|
||||
CODE_MAX_STRING_LENGTH=400000
|
||||
@ -474,9 +445,6 @@ HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY=True
|
||||
|
||||
# Webhook request configuration
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE=10485760
|
||||
|
||||
# Respect X-* headers to redirect clients
|
||||
RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
@ -532,7 +500,7 @@ API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node
|
||||
API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
|
||||
# Workflow log cleanup configuration
|
||||
# Enable automatic cleanup of workflow run logs to manage database size
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=false
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED=true
|
||||
# Number of days to retain workflow run logs (default: 30 days)
|
||||
WORKFLOW_LOG_RETENTION_DAYS=30
|
||||
# Batch size for workflow log cleanup operations (default: 100)
|
||||
@ -540,7 +508,6 @@ WORKFLOW_LOG_CLEANUP_BATCH_SIZE=100
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_DEFAULT_ACTIVE_REQUESTS=0
|
||||
APP_MAX_ACTIVE_REQUESTS=0
|
||||
|
||||
# Celery beat configuration
|
||||
@ -555,12 +522,6 @@ ENABLE_CLEAN_MESSAGES=false
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
|
||||
ENABLE_DATASETS_QUEUE_MONITOR=false
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK=true
|
||||
# Interval time in minutes for polling scheduled workflows(default: 1 min)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL=1
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
|
||||
# Maximum number of scheduled workflows to dispatch per tick (0 for unlimited)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
|
||||
|
||||
# Position configuration
|
||||
POSITION_TOOL_PINS=
|
||||
@ -632,9 +593,3 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Whether to encrypt dataset IDs when exporting DSL files (default: true)
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
# Maximum number of segments for dataset segments API (0 for unlimited)
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST=0
|
||||
|
||||
@ -16,7 +16,6 @@ layers =
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
runtime
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
|
||||
@ -81,6 +81,7 @@ ignore = [
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
"SIM210", # if-expr-with-true-false
|
||||
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
# Agent Skill Index
|
||||
|
||||
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Platform Foundations
|
||||
|
||||
- **[Infrastructure Overview](agent_skills/infra.md)**\
|
||||
When to read this:
|
||||
|
||||
- You need to understand where a feature belongs in the architecture.
|
||||
- You’re wiring storage, Redis, vector stores, or OTEL.
|
||||
- You’re about to add CLI commands or async jobs.\
|
||||
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
|
||||
|
||||
- **[Coding Style](agent_skills/coding_style.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re writing or reviewing backend code and need the authoritative checklist.
|
||||
- You’re unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
|
||||
- You want the exact lint/type/test commands used in PRs.\
|
||||
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Plugin & Extension Development
|
||||
|
||||
- **[Plugin Systems](agent_skills/plugin.md)**\
|
||||
When to read this:
|
||||
|
||||
- You’re building or debugging a marketplace plugin.
|
||||
- You need to know how manifests, providers, daemons, and migrations fit together.\
|
||||
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
|
||||
|
||||
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
|
||||
When to read this:
|
||||
|
||||
- You must integrate OAuth for a plugin or datasource.
|
||||
- You’re handling credential encryption or refresh flows.\
|
||||
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Workflow Entry & Execution
|
||||
|
||||
- **[Trigger Concepts](agent_skills/trigger.md)**\
|
||||
When to read this:
|
||||
- You’re debugging why a workflow didn’t start.
|
||||
- You’re adding a new trigger type or hook.
|
||||
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
|
||||
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
|
||||
|
||||
______________________________________________________________________
|
||||
|
||||
## Additional Notes for Agents
|
||||
|
||||
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
|
||||
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
|
||||
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
|
||||
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
|
||||
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
|
||||
@ -15,11 +15,7 @@ FROM base AS packages
|
||||
# RUN sed -i 's@deb.debian.org@mirrors.aliyun.com@g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
g++ \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
@ -48,22 +44,14 @@ ENV PYTHONIOENCODING=utf-8
|
||||
|
||||
WORKDIR /app/api
|
||||
|
||||
# Create non-root user
|
||||
ARG dify_uid=1001
|
||||
RUN groupadd -r -g ${dify_uid} dify && \
|
||||
useradd -r -u ${dify_uid} -g ${dify_uid} -s /bin/bash dify && \
|
||||
chown -R dify:dify /app
|
||||
|
||||
RUN \
|
||||
apt-get update \
|
||||
# Install dependencies
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
# basic environment
|
||||
curl nodejs \
|
||||
# for gmpy2 \
|
||||
libgmp-dev libmpfr-dev libmpc-dev \
|
||||
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
|
||||
# For Security
|
||||
expat libldap-2.5-0=2.5.13+dfsg-5 perl libsqlite3-0=3.40.1-2+deb12u2 zlib1g=1:1.2.13.dfsg-1 \
|
||||
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
|
||||
# install fonts to support the use of tools like pypdfium2
|
||||
fonts-noto-cjk \
|
||||
# install a package to improve the accuracy of guessing mime type and file extension
|
||||
@ -75,29 +63,24 @@ RUN \
|
||||
|
||||
# Copy Python environment and packages
|
||||
ENV VIRTUAL_ENV=/app/api/.venv
|
||||
COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --from=packages ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
|
||||
|
||||
# Download nltk data
|
||||
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
|
||||
&& chmod -R 755 /usr/local/share/nltk_data
|
||||
RUN python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger')"
|
||||
|
||||
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
|
||||
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
|
||||
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
|
||||
RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')"
|
||||
|
||||
# Copy source code
|
||||
COPY --chown=dify:dify . /app/api/
|
||||
|
||||
# Prepare entrypoint script
|
||||
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
|
||||
COPY . /app/api/
|
||||
|
||||
# Copy entrypoint
|
||||
COPY docker/entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA=${COMMIT_SHA}
|
||||
ENV NLTK_DATA=/usr/local/share/nltk_data
|
||||
|
||||
USER dify
|
||||
|
||||
ENTRYPOINT ["/bin/bash", "/entrypoint.sh"]
|
||||
|
||||
@ -15,8 +15,8 @@
|
||||
```bash
|
||||
cd ../docker
|
||||
cp middleware.env.example middleware.env
|
||||
# change the profile to mysql if you are not using postgres,change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
# change the profile to other vector database if you are not using weaviate
|
||||
docker compose -f docker-compose.middleware.yaml --profile weaviate -p dify up -d
|
||||
cd ../api
|
||||
```
|
||||
|
||||
@ -26,10 +26,6 @@
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> When the frontend and backend run on different subdomains, set COOKIE_DOMAIN to the site’s top-level domain (e.g., `example.com`). The frontend and backend must be under the same top-level domain in order to share authentication cookies.
|
||||
|
||||
1. Generate a `SECRET_KEY` in the `.env` file.
|
||||
|
||||
bash for Linux
|
||||
@ -84,7 +80,7 @@
|
||||
1. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
|
||||
|
||||
```bash
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor
|
||||
uv run celery -A app.celery worker -P gevent -c 2 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation
|
||||
```
|
||||
|
||||
Additionally, if you want to debug the celery scheduled tasks, you can run the following command in another terminal to start the beat service:
|
||||
|
||||
@ -1,115 +0,0 @@
|
||||
## Linter
|
||||
|
||||
- Always follow `.ruff.toml`.
|
||||
- Run `uv run ruff check --fix --unsafe-fixes`.
|
||||
- Keep each line under 100 characters (including spaces).
|
||||
|
||||
## Code Style
|
||||
|
||||
- `snake_case` for variables and functions.
|
||||
- `PascalCase` for classes.
|
||||
- `UPPER_CASE` for constants.
|
||||
|
||||
## Rules
|
||||
|
||||
- Use Pydantic v2 standard.
|
||||
- Use `uv` for package management.
|
||||
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
|
||||
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
|
||||
- Prefer simple functions over classes for lightweight helpers.
|
||||
- Keep files below 800 lines; split when necessary.
|
||||
- Keep code readable—no clever hacks.
|
||||
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
|
||||
|
||||
## Guiding Principles
|
||||
|
||||
- Mirror the project’s layered architecture: controller → service → core/domain.
|
||||
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
|
||||
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
|
||||
|
||||
## SQLAlchemy Patterns
|
||||
|
||||
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
|
||||
|
||||
- Open sessions with context managers:
|
||||
|
||||
```python
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.id == workflow_id,
|
||||
Workflow.tenant_id == tenant_id,
|
||||
)
|
||||
workflow = session.execute(stmt).scalar_one_or_none()
|
||||
```
|
||||
|
||||
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
|
||||
|
||||
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
|
||||
|
||||
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
|
||||
|
||||
## Storage & External IO
|
||||
|
||||
- Access storage via `extensions.ext_storage.storage`.
|
||||
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
|
||||
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
|
||||
|
||||
## Pydantic Usage
|
||||
|
||||
- Define DTOs with Pydantic v2 models and forbid extras by default.
|
||||
|
||||
- Use `@field_validator` / `@model_validator` for domain rules.
|
||||
|
||||
- Example:
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
|
||||
|
||||
class TriggerConfig(BaseModel):
|
||||
endpoint: HttpUrl
|
||||
secret: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
@field_validator("secret")
|
||||
def ensure_secret_prefix(cls, value: str) -> str:
|
||||
if not value.startswith("dify_"):
|
||||
raise ValueError("secret must start with dify_")
|
||||
return value
|
||||
```
|
||||
|
||||
## Generics & Protocols
|
||||
|
||||
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
|
||||
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
|
||||
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
|
||||
|
||||
## Error Handling & Logging
|
||||
|
||||
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
|
||||
- Declare `logger = logging.getLogger(__name__)` at module top.
|
||||
- Include tenant/app/workflow identifiers in log context.
|
||||
- Log retryable events at `warning`, terminal failures at `error`.
|
||||
|
||||
## Tooling & Checks
|
||||
|
||||
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
|
||||
- Type checks: `uv run --directory api --dev basedpyright`.
|
||||
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
|
||||
- Run all of the above before submitting your work.
|
||||
|
||||
## Controllers & Services
|
||||
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
|
||||
- Document non-obvious behaviour with concise comments.
|
||||
|
||||
## Miscellaneous
|
||||
|
||||
- Use `configs.dify_config` for configuration—never read environment variables directly.
|
||||
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
|
||||
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
|
||||
- Keep experimental scripts under `dev/`; do not ship them in production builds.
|
||||
@ -1,96 +0,0 @@
|
||||
## Configuration
|
||||
|
||||
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
|
||||
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
|
||||
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
|
||||
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
|
||||
|
||||
## Dependencies
|
||||
|
||||
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
|
||||
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
|
||||
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
|
||||
|
||||
## Storage & Files
|
||||
|
||||
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
|
||||
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
|
||||
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
|
||||
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
|
||||
|
||||
## Redis & Shared State
|
||||
|
||||
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
|
||||
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
|
||||
|
||||
## Models
|
||||
|
||||
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
|
||||
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
|
||||
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
|
||||
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
|
||||
|
||||
## Vector Stores
|
||||
|
||||
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
|
||||
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
|
||||
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
|
||||
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
|
||||
|
||||
## Observability & OTEL
|
||||
|
||||
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
|
||||
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
|
||||
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
|
||||
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
|
||||
|
||||
## Ops Integrations
|
||||
|
||||
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
|
||||
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
|
||||
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
|
||||
|
||||
## Controllers, Services, Core
|
||||
|
||||
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
|
||||
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
|
||||
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
|
||||
|
||||
## Plugins, Tools, Providers
|
||||
|
||||
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
|
||||
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
|
||||
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
|
||||
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
|
||||
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
|
||||
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
|
||||
|
||||
## Async Workloads
|
||||
|
||||
see `agent_skills/trigger.md` for more detailed documentation.
|
||||
|
||||
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
|
||||
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
|
||||
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
|
||||
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
|
||||
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
|
||||
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
|
||||
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
|
||||
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
|
||||
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
|
||||
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
|
||||
|
||||
## When You Add Features
|
||||
|
||||
- Check for an existing helper or service before writing a new util.
|
||||
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
|
||||
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
|
||||
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.
|
||||
@ -1 +0,0 @@
|
||||
// TBD
|
||||
@ -1 +0,0 @@
|
||||
// TBD
|
||||
@ -1,53 +0,0 @@
|
||||
## Overview
|
||||
|
||||
Trigger is a collection of nodes that we called `Start` nodes, also, the concept of `Start` is the same as `RootNode` in the workflow engine `core/workflow/graph_engine`, On the other hand, `Start` node is the entry point of workflows, every workflow run always starts from a `Start` node.
|
||||
|
||||
## Trigger nodes
|
||||
|
||||
- `UserInput`
|
||||
- `Trigger Webhook`
|
||||
- `Trigger Schedule`
|
||||
- `Trigger Plugin`
|
||||
|
||||
### UserInput
|
||||
|
||||
Before `Trigger` concept is introduced, it's what we called `Start` node, but now, to avoid confusion, it was renamed to `UserInput` node, has a strong relation with `ServiceAPI` in `controllers/service_api/app`
|
||||
|
||||
1. `UserInput` node introduces a list of arguments that need to be provided by the user, finally it will be converted into variables in the workflow variable pool.
|
||||
1. `ServiceAPI` accept those arguments, and pass through them into `UserInput` node.
|
||||
1. For its detailed implementation, please refer to `core/workflow/nodes/start`
|
||||
|
||||
### Trigger Webhook
|
||||
|
||||
Inside Webhook Node, Dify provided a UI panel that allows user define a HTTP manifest `core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`, also, Dify generates a random webhook id for each `Trigger Webhook` node, the implementation was implemented in `core/trigger/utils/endpoint.py`, as you can see, `webhook-debug` is a debug mode for webhook, you may find it in `controllers/trigger/webhook.py`.
|
||||
|
||||
Finally, requests to `webhook` endpoint will be converted into variables in workflow variable pool during workflow execution.
|
||||
|
||||
### Trigger Schedule
|
||||
|
||||
`Trigger Schedule` node is a node that allows user define a schedule to trigger the workflow, detailed manifest is here `core/workflow/nodes/trigger_schedule/entities.py`, we have a poller and executor to handle millions of schedules, see `docker/entrypoint.sh` / `schedule/workflow_schedule_task.py` for help.
|
||||
|
||||
To Achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and a `events/event_handlers/sync_workflow_schedule_when_app_published.py` was used to sync workflow schedule plans when app is published.
|
||||
|
||||
### Trigger Plugin
|
||||
|
||||
`Trigger Plugin` node allows user define there own distributed trigger plugin, whenever a request was received, Dify forwards it to the plugin and wait for parsed variables from it.
|
||||
|
||||
1. Requests were saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`
|
||||
1. Plugins accept those requests and parse variables from it, see `core/plugin/impl/trigger.py` for details.
|
||||
|
||||
A `subscription` concept was out here by Dify, it means an endpoint address from Dify was bound to thirdparty webhook service like `Github` `Slack` `Linear` `GoogleDrive` `Gmail` etc. Once a subscription was created, Dify continually receives requests from the platforms and handle them one by one.
|
||||
|
||||
## Worker Pool / Async Task
|
||||
|
||||
All the events that triggered a new workflow run is always in async mode, a unified entrypoint can be found here `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
|
||||
|
||||
The infrastructure we used is `celery`, we've already configured it in `docker/entrypoint.sh`, and the consumers are in `tasks/async_workflow_tasks.py`, 3 queues were used to handle different tiers of users, `PROFESSIONAL_QUEUE` `TEAM_QUEUE` `SANDBOX_QUEUE`.
|
||||
|
||||
## Debug Strategy
|
||||
|
||||
Dify divided users into 2 groups: builders / end users.
|
||||
|
||||
Builders are the users who create workflows, in this stage, debugging a workflow becomes a critical part of the workflow development process, as the start node in workflows, trigger nodes can `listen` to the events from `WebhookDebug` `Schedule` `Plugin`, debugging process was created in `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
|
||||
|
||||
A polling process can be considered as combine of few single `poll` operations, each `poll` operation fetches events cached in `Redis`, returns `None` if no event was found, more detailed implemented: `core/trigger/debug/event_bus.py` was used to handle the polling process, and `core/trigger/debug/event_selectors.py` was used to select the event poller based on the trigger type.
|
||||
21
api/app.py
21
api/app.py
@ -1,7 +1,7 @@
|
||||
import sys
|
||||
|
||||
|
||||
def is_db_command() -> bool:
|
||||
def is_db_command():
|
||||
if len(sys.argv) > 1 and sys.argv[0].endswith("flask") and sys.argv[1] == "db":
|
||||
return True
|
||||
return False
|
||||
@ -13,12 +13,23 @@ if is_db_command():
|
||||
|
||||
app = create_migrations_app()
|
||||
else:
|
||||
# Gunicorn and Celery handle monkey patching automatically in production by
|
||||
# specifying the `gevent` worker class. Manual monkey patching is not required here.
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
#
|
||||
# See `api/docker/entrypoint.sh` (lines 33 and 47) for details.
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
#
|
||||
# For third-party library patching, refer to `gunicorn.conf.py` and `celery_entrypoint.py`.
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
@ -18,7 +18,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
dify_app.config["RESTX_INCLUDE_ALL_MODELS"] = True
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
@ -51,7 +50,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
ext_import_modules,
|
||||
ext_logging,
|
||||
@ -76,7 +74,6 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_warnings,
|
||||
ext_import_modules,
|
||||
ext_orjson,
|
||||
ext_forward_refs,
|
||||
ext_set_secretkey,
|
||||
ext_compress,
|
||||
ext_code_based_extension,
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
for pattern in "Base" "TypeBase"; do
|
||||
printf "%s " "$pattern"
|
||||
grep "($pattern):" -r --include='*.py' --exclude-dir=".venv" --exclude-dir="tests" . | wc -l
|
||||
done
|
||||
@ -15,12 +15,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -321,8 +321,6 @@ def migrate_knowledge_vector_database():
|
||||
)
|
||||
|
||||
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
|
||||
if not datasets.items:
|
||||
break
|
||||
except SQLAlchemyError:
|
||||
raise
|
||||
|
||||
@ -1229,55 +1227,6 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_system_trigger_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup system trigger oauth client
|
||||
"""
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerOAuthSystemClient
|
||||
|
||||
provider_id = TriggerProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
deleted_count = (
|
||||
db.session.query(TriggerOAuthSystemClient)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
oauth_client = TriggerOAuthSystemClient(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
encrypted_oauth_params=oauth_client_params,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
"""
|
||||
Find draft variables that reference non-existent apps.
|
||||
@ -1471,10 +1420,7 @@ def setup_datasource_oauth_client(provider, client_params):
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
@click.option(
|
||||
"--environment", prompt=True, help="the environment to transform datasource credentials", default="online"
|
||||
)
|
||||
def transform_datasource_credentials(environment: str):
|
||||
def transform_datasource_credentials():
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
@ -1485,14 +1431,9 @@ def transform_datasource_credentials(environment: str):
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
if environment == "online":
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
else:
|
||||
notion_plugin_unique_identifier = None
|
||||
firecrawl_plugin_unique_identifier = None
|
||||
jina_plugin_unique_identifier = None
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
@ -1580,14 +1521,6 @@ def transform_datasource_credentials(environment: str):
|
||||
auth_count = 0
|
||||
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not firecrawl_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
@ -1643,14 +1576,6 @@ def transform_datasource_credentials(environment: str):
|
||||
auth_count = 0
|
||||
for jina_tenant_credential in jina_tenant_credentials:
|
||||
auth_count += 1
|
||||
if not jina_tenant_credential.credentials:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Skipping jina credential for tenant {tenant_id} due to missing credentials.",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
# get credential api key
|
||||
credentials_json = json.loads(jina_tenant_credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
@ -1658,7 +1583,7 @@ def transform_datasource_credentials(environment: str):
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jinareader",
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
|
||||
@ -73,14 +73,14 @@ class AppExecutionConfig(BaseSettings):
|
||||
description="Maximum allowed execution time for the application in seconds",
|
||||
default=1200,
|
||||
)
|
||||
APP_DEFAULT_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Default number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_MAX_ACTIVE_REQUESTS: NonNegativeInt = Field(
|
||||
description="Maximum number of concurrent active requests per app (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
APP_DAILY_RATE_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum number of requests per app per day",
|
||||
default=5000,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
@ -174,33 +174,6 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TriggerConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for trigger
|
||||
"""
|
||||
|
||||
WEBHOOK_REQUEST_BODY_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size for webhook request bodies in bytes",
|
||||
default=10485760,
|
||||
)
|
||||
|
||||
|
||||
class AsyncWorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for async workflow
|
||||
"""
|
||||
|
||||
ASYNC_WORKFLOW_SCHEDULER_GRANULARITY: int = Field(
|
||||
description="Granularity for async workflow scheduler, "
|
||||
"sometime, few users could block the queue due to some time-consuming tasks, "
|
||||
"to avoid this, workflow can be suspended if needed, to achieve"
|
||||
"this, a time-based checker is required, every granularity seconds, "
|
||||
"the checker will check the workflow queue and suspend the workflow",
|
||||
default=120,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class PluginConfig(BaseSettings):
|
||||
"""
|
||||
Plugin configs
|
||||
@ -216,11 +189,6 @@ class PluginConfig(BaseSettings):
|
||||
default="plugin-api-key",
|
||||
)
|
||||
|
||||
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
|
||||
default=300.0,
|
||||
)
|
||||
|
||||
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
|
||||
|
||||
PLUGIN_REMOTE_INSTALL_HOST: str = Field(
|
||||
@ -290,8 +258,6 @@ class EndpointConfig(BaseSettings):
|
||||
description="Template url for endpoint plugin", default="http://localhost:5002/e/{hook_id}"
|
||||
)
|
||||
|
||||
TRIGGER_URL: str = Field(description="Template url for triggers", default="http://localhost:5001")
|
||||
|
||||
|
||||
class FileAccessConfig(BaseSettings):
|
||||
"""
|
||||
@ -360,42 +326,12 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
|
||||
description=(
|
||||
"Comma-separated list of file extensions that are blocked from upload. "
|
||||
"Extensions should be lowercase without dots (e.g., 'exe,bat,sh,dll'). "
|
||||
"Empty by default to allow all file types."
|
||||
),
|
||||
validation_alias=AliasChoices("UPLOAD_FILE_EXTENSION_BLACKLIST"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def UPLOAD_FILE_EXTENSION_BLACKLIST(self) -> set[str]:
|
||||
"""
|
||||
Parse and return the blacklist as a set of lowercase extensions.
|
||||
Returns an empty set if no blacklist is configured.
|
||||
"""
|
||||
if not self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST:
|
||||
return set()
|
||||
return {
|
||||
ext.strip().lower().strip(".")
|
||||
for ext in self.inner_UPLOAD_FILE_EXTENSION_BLACKLIST.split(",")
|
||||
if ext.strip()
|
||||
}
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
HTTP-related configurations for the application
|
||||
"""
|
||||
|
||||
COOKIE_DOMAIN: str = Field(
|
||||
description="Explicit cookie domain for console/service cookies when sharing across subdomains",
|
||||
default="",
|
||||
)
|
||||
|
||||
API_COMPRESSION_ENABLED: bool = Field(
|
||||
description="Enable or disable gzip compression for HTTP responses",
|
||||
default=False,
|
||||
@ -426,11 +362,11 @@ class HttpConfig(BaseSettings):
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=600
|
||||
ge=1, description="Maximum read timeout in seconds for HTTP requests", default=60
|
||||
)
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=600
|
||||
ge=1, description="Maximum write timeout in seconds for HTTP requests", default=20
|
||||
)
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
@ -607,7 +543,7 @@ class UpdateConfig(BaseSettings):
|
||||
|
||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||
# 1000 KiB
|
||||
# 100KB
|
||||
1024_000,
|
||||
description="Maximum size for variable to trigger final truncation.",
|
||||
)
|
||||
@ -835,7 +771,7 @@ class MailConfig(BaseSettings):
|
||||
|
||||
MAIL_TEMPLATING_TIMEOUT: int = Field(
|
||||
description="""
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Timeout for email templating in seconds. Used to prevent infinite loops in malicious templates.
|
||||
Only available in sandbox mode.""",
|
||||
default=3,
|
||||
)
|
||||
@ -974,11 +910,6 @@ class DataSetConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
DATASET_MAX_SEGMENTS_PER_REQUEST: NonNegativeInt = Field(
|
||||
description="Maximum number of segments for dataset segments API (0 for unlimited)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceConfig(BaseSettings):
|
||||
"""
|
||||
@ -1054,44 +985,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK: bool = Field(
|
||||
description="Enable workflow schedule poller task",
|
||||
default=True,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_INTERVAL: int = Field(
|
||||
description="Workflow schedule poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE: int = Field(
|
||||
description="Maximum number of schedules to process in each poll batch",
|
||||
default=100,
|
||||
)
|
||||
WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK: int = Field(
|
||||
description="Maximum schedules to dispatch per tick (0=unlimited, circuit breaker)",
|
||||
default=0,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
default=True,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_INTERVAL: int = Field(
|
||||
description="Trigger provider refresh poller interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
TRIGGER_PROVIDER_REFRESH_BATCH_SIZE: int = Field(
|
||||
description="Max trigger subscriptions to process per tick",
|
||||
default=200,
|
||||
)
|
||||
TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive credential refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS: int = Field(
|
||||
description="Proactive subscription refresh threshold in seconds",
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
|
||||
class PositionConfig(BaseSettings):
|
||||
@ -1190,7 +1083,7 @@ class AccountConfig(BaseSettings):
|
||||
|
||||
|
||||
class WorkflowLogConfig(BaseSettings):
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=False, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_CLEANUP_ENABLED: bool = Field(default=True, description="Enable workflow run log cleanup")
|
||||
WORKFLOW_LOG_RETENTION_DAYS: int = Field(default=30, description="Retention days for workflow run logs")
|
||||
WORKFLOW_LOG_CLEANUP_BATCH_SIZE: int = Field(
|
||||
default=100, description="Batch size for workflow run log cleanup operations"
|
||||
@ -1209,21 +1102,12 @@ class SwaggerUIConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY: int = Field(
|
||||
description="Number of tasks allowed to be delivered concurrently from isolated queue per tenant",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
DataSetConfig,
|
||||
@ -1242,7 +1126,6 @@ class FeatureConfig(
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
|
||||
@ -18,7 +18,6 @@ from .storage.opendal_storage_config import OpenDALStorageConfig
|
||||
from .storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from .vdb.alibabacloud_mysql_config import AlibabaCloudMySQLConfig
|
||||
from .vdb.analyticdb_config import AnalyticdbConfig
|
||||
from .vdb.baidu_vector_config import BaiduVectorDBConfig
|
||||
from .vdb.chroma_config import ChromaConfig
|
||||
@ -105,12 +104,6 @@ class KeywordStoreConfig(BaseSettings):
|
||||
|
||||
|
||||
class DatabaseConfig(BaseSettings):
|
||||
# Database type selector
|
||||
DB_TYPE: Literal["postgresql", "mysql", "oceanbase"] = Field(
|
||||
description="Database type to use. OceanBase is MySQL-compatible.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
DB_HOST: str = Field(
|
||||
description="Hostname or IP address of the database server.",
|
||||
default="localhost",
|
||||
@ -146,12 +139,12 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
return "postgresql" if self.DB_TYPE == "postgresql" else "mysql+pymysql"
|
||||
SQLALCHEMY_DATABASE_URI_SCHEME: str = Field(
|
||||
description="Database URI scheme for SQLAlchemy connection.",
|
||||
default="postgresql",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
db_extras = (
|
||||
@ -204,21 +197,21 @@ class DatabaseConfig(BaseSettings):
|
||||
default=os.cpu_count() or 1,
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||
# Parse DB_EXTRAS for 'options'
|
||||
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
|
||||
options = db_extras_dict.get("options", "")
|
||||
connect_args = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
# Always include timezone
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
# Merge user options and timezone
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
return {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
@ -337,7 +330,6 @@ class MiddlewareConfig(
|
||||
ClickzettaConfig,
|
||||
HuaweiCloudConfig,
|
||||
MilvusConfig,
|
||||
AlibabaCloudMySQLConfig,
|
||||
MyScaleConfig,
|
||||
OpenSearchConfig,
|
||||
OracleConfig,
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AlibabaCloudMySQLConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for AlibabaCloud MySQL vector database
|
||||
"""
|
||||
|
||||
ALIBABACLOUD_MYSQL_HOST: str = Field(
|
||||
description="Hostname or IP address of the AlibabaCloud MySQL server (e.g., 'localhost' or 'mysql.aliyun.com')",
|
||||
default="localhost",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PORT: PositiveInt = Field(
|
||||
description="Port number on which the AlibabaCloud MySQL server is listening (default is 3306)",
|
||||
default=3306,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_USER: str = Field(
|
||||
description="Username for authenticating with AlibabaCloud MySQL (default is 'root')",
|
||||
default="root",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_PASSWORD: str = Field(
|
||||
description="Password for authenticating with AlibabaCloud MySQL (default is an empty string)",
|
||||
default="",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DATABASE: str = Field(
|
||||
description="Name of the AlibabaCloud MySQL database to connect to (default is 'dify')",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_MAX_CONNECTION: PositiveInt = Field(
|
||||
description="Maximum number of connections in the connection pool",
|
||||
default=5,
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_CHARSET: str = Field(
|
||||
description="Character set for AlibabaCloud MySQL connection (default is 'utf8mb4')",
|
||||
default="utf8mb4",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_DISTANCE_FUNCTION: str = Field(
|
||||
description="Distance function used for vector similarity search in AlibabaCloud MySQL "
|
||||
"(e.g., 'cosine', 'euclidean')",
|
||||
default="cosine",
|
||||
)
|
||||
|
||||
ALIBABACLOUD_MYSQL_HNSW_M: PositiveInt = Field(
|
||||
description="Maximum number of connections per layer for HNSW vector index (default is 6, range: 3-200)",
|
||||
default=6,
|
||||
)
|
||||
@ -1,24 +1,23 @@
|
||||
from enum import StrEnum
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AuthMethod(StrEnum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
|
||||
class OpenSearchConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for OpenSearch
|
||||
"""
|
||||
|
||||
class AuthMethod(Enum):
|
||||
"""
|
||||
Authentication method for OpenSearch
|
||||
"""
|
||||
|
||||
BASIC = "basic"
|
||||
AWS_MANAGED_IAM = "aws_managed_iam"
|
||||
|
||||
OPENSEARCH_HOST: str | None = Field(
|
||||
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
|
||||
default=None,
|
||||
|
||||
@ -22,17 +22,7 @@ class WeaviateConfig(BaseSettings):
|
||||
default=True,
|
||||
)
|
||||
|
||||
WEAVIATE_GRPC_ENDPOINT: str | None = Field(
|
||||
description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
WEAVIATE_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Number of objects to be processed in a single batch operation (default is 100)",
|
||||
default=100,
|
||||
)
|
||||
|
||||
WEAVIATE_TOKENIZATION: str | None = Field(
|
||||
description="Tokenization for Weaviate (default is word)",
|
||||
default="word",
|
||||
)
|
||||
|
||||
@ -55,16 +55,3 @@ else:
|
||||
"properties",
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||
|
||||
# webapp
|
||||
COOKIE_NAME_WEBAPP_ACCESS_TOKEN = "webapp_access_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
|
||||
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||
|
||||
@ -31,9 +31,3 @@ def supported_language(lang):
|
||||
|
||||
error = f"{lang} is not a valid language."
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_valid_language(lang: str | None) -> str:
|
||||
if lang and lang in languages:
|
||||
return lang
|
||||
return languages[0]
|
||||
|
||||
File diff suppressed because one or more lines are too long
@ -9,7 +9,6 @@ if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.trigger.provider import PluginTriggerProviderController
|
||||
|
||||
|
||||
"""
|
||||
@ -42,11 +41,3 @@ datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginPro
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
plugin_trigger_providers: RecyclableContextVar[dict[str, "PluginTriggerProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers")
|
||||
)
|
||||
|
||||
plugin_trigger_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_trigger_providers_lock")
|
||||
)
|
||||
|
||||
@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException):
|
||||
code = 415
|
||||
|
||||
|
||||
class BlockedFileExtensionError(BaseHTTPException):
|
||||
error_code = "file_extension_blocked"
|
||||
description = "The file extension is blocked for security reasons."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = "too_many_files"
|
||||
description = "Only one file is allowed."
|
||||
|
||||
@ -24,7 +24,7 @@ except ImportError:
|
||||
)
|
||||
else:
|
||||
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
||||
magic = None # type: ignore[assignment]
|
||||
magic = None # type: ignore
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -66,7 +66,6 @@ from .app import (
|
||||
workflow_draft_variable,
|
||||
workflow_run,
|
||||
workflow_statistic,
|
||||
workflow_trigger,
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
@ -127,7 +126,6 @@ from .workspace import (
|
||||
models,
|
||||
plugin,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
)
|
||||
|
||||
@ -198,7 +196,6 @@ __all__ = [
|
||||
"statistic",
|
||||
"tags",
|
||||
"tool_providers",
|
||||
"trigger_providers",
|
||||
"version",
|
||||
"website",
|
||||
"workflow",
|
||||
@ -206,6 +203,5 @@ __all__ = [
|
||||
"workflow_draft_variable",
|
||||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
"workspace",
|
||||
]
|
||||
|
||||
@ -12,10 +12,9 @@ P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@ -25,9 +24,19 @@ def admin_required(view: Callable[P, R]):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_token = extract_access_token(request)
|
||||
if not auth_token:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if auth_token != dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -38,10 +47,10 @@ def admin_required(view: Callable[P, R]):
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps")
|
||||
class InsertExploreAppListApi(Resource):
|
||||
@console_ns.doc("insert_explore_app")
|
||||
@console_ns.doc(description="Insert or update an app in the explore list")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("insert_explore_app")
|
||||
@api.doc(description="Insert or update an app in the explore list")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InsertExploreAppRequest",
|
||||
{
|
||||
"app_id": fields.String(required=True, description="Application ID"),
|
||||
@ -55,23 +64,21 @@ class InsertExploreAppListApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "App updated successfully")
|
||||
@console_ns.response(201, "App inserted successfully")
|
||||
@console_ns.response(404, "App not found")
|
||||
@api.response(200, "App updated successfully")
|
||||
@api.response(201, "App inserted successfully")
|
||||
@api.response(404, "App not found")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("desc", type=str, location="json")
|
||||
.add_argument("copyright", type=str, location="json")
|
||||
.add_argument("privacy_policy", type=str, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, location="json")
|
||||
.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("desc", type=str, location="json")
|
||||
parser.add_argument("copyright", type=str, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, location="json")
|
||||
parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json")
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
|
||||
@ -131,10 +138,10 @@ class InsertExploreAppListApi(Resource):
|
||||
|
||||
@console_ns.route("/admin/insert-explore-apps/<uuid:app_id>")
|
||||
class InsertExploreAppApi(Resource):
|
||||
@console_ns.doc("delete_explore_app")
|
||||
@console_ns.doc(description="Remove an app from the explore list")
|
||||
@console_ns.doc(params={"app_id": "Application ID to remove"})
|
||||
@console_ns.response(204, "App removed successfully")
|
||||
@api.doc("delete_explore_app")
|
||||
@api.doc(description="Remove an app from the explore list")
|
||||
@api.doc(params={"app_id": "Application ID to remove"})
|
||||
@api.response(204, "App removed successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def delete(self, app_id):
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import flask_restx
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import select
|
||||
@ -7,12 +8,12 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from . import api, console_ns
|
||||
from .wraps import account_initialization_required, setup_required
|
||||
|
||||
api_key_fields = {
|
||||
"id": fields.String,
|
||||
@ -24,12 +25,6 @@ api_key_fields = {
|
||||
|
||||
api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")}
|
||||
|
||||
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
|
||||
|
||||
api_key_list_model = console_ns.model(
|
||||
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
|
||||
)
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
if resource_model == App:
|
||||
@ -58,13 +53,11 @@ class BaseApiKeyListResource(Resource):
|
||||
token_prefix: str | None = None
|
||||
max_keys = 10
|
||||
|
||||
@marshal_with(api_key_list_model)
|
||||
@marshal_with(api_key_list)
|
||||
def get(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
keys = db.session.scalars(
|
||||
select(ApiToken).where(
|
||||
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
|
||||
@ -72,13 +65,14 @@ class BaseApiKeyListResource(Resource):
|
||||
).all()
|
||||
return {"items": keys}
|
||||
|
||||
@marshal_with(api_key_item_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(api_key_fields)
|
||||
def post(self, resource_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
resource_id = str(resource_id)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
|
||||
@ -95,7 +89,7 @@ class BaseApiKeyListResource(Resource):
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_tenant_id
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
api_token.token = key
|
||||
api_token.type = self.resource_type
|
||||
db.session.add(api_token)
|
||||
@ -110,11 +104,13 @@ class BaseApiKeyResource(Resource):
|
||||
resource_model: type | None = None
|
||||
resource_id_field: str | None = None
|
||||
|
||||
def delete(self, resource_id: str, api_key_id: str):
|
||||
def delete(self, resource_id, api_key_id):
|
||||
assert self.resource_id_field is not None, "resource_id_field must be set"
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
resource_id = str(resource_id)
|
||||
api_key_id = str(api_key_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
@ -139,23 +135,28 @@ class BaseApiKeyResource(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_app_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
@api.doc("get_app_api_keys")
|
||||
@api.doc(description="Get all API keys for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
"""Get all API keys for an app"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_app_api_key")
|
||||
@console_ns.doc(description="Create a new API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
@api.doc("create_app_api_key")
|
||||
@api.doc(description="Create a new API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
"""Create a new API key for an app"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -164,14 +165,19 @@ class AppApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class AppApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc("delete_app_api_key")
|
||||
@console_ns.doc(description="Delete an API key for an app")
|
||||
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@api.doc("delete_app_api_key")
|
||||
@api.doc(description="Delete an API key for an app")
|
||||
@api.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for an app"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "app"
|
||||
resource_model = App
|
||||
resource_id_field = "app_id"
|
||||
@ -179,23 +185,28 @@ class AppApiKeyResource(BaseApiKeyResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys")
|
||||
class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
@console_ns.doc("get_dataset_api_keys")
|
||||
@console_ns.doc(description="Get all API keys for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Success", api_key_list_model)
|
||||
def get(self, resource_id): # type: ignore
|
||||
@api.doc("get_dataset_api_keys")
|
||||
@api.doc(description="Get all API keys for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(200, "Success", api_key_list)
|
||||
def get(self, resource_id):
|
||||
"""Get all API keys for a dataset"""
|
||||
return super().get(resource_id)
|
||||
|
||||
@console_ns.doc("create_dataset_api_key")
|
||||
@console_ns.doc(description="Create a new API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID"})
|
||||
@console_ns.response(201, "API key created successfully", api_key_item_model)
|
||||
@console_ns.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id): # type: ignore
|
||||
@api.doc("create_dataset_api_key")
|
||||
@api.doc(description="Create a new API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID"})
|
||||
@api.response(201, "API key created successfully", api_key_fields)
|
||||
@api.response(400, "Maximum keys exceeded")
|
||||
def post(self, resource_id):
|
||||
"""Create a new API key for a dataset"""
|
||||
return super().post(resource_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
@ -204,14 +215,19 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>")
|
||||
class DatasetApiKeyResource(BaseApiKeyResource):
|
||||
@console_ns.doc("delete_dataset_api_key")
|
||||
@console_ns.doc(description="Delete an API key for a dataset")
|
||||
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@console_ns.response(204, "API key deleted successfully")
|
||||
@api.doc("delete_dataset_api_key")
|
||||
@api.doc(description="Delete an API key for a dataset")
|
||||
@api.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
|
||||
@api.response(204, "API key deleted successfully")
|
||||
def delete(self, resource_id, api_key_id):
|
||||
"""Delete an API key for a dataset"""
|
||||
return super().delete(resource_id, api_key_id)
|
||||
|
||||
def after_request(self, resp):
|
||||
resp.headers["Access-Control-Allow-Origin"] = "*"
|
||||
resp.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
return resp
|
||||
|
||||
resource_type = "dataset"
|
||||
resource_model = Dataset
|
||||
resource_id_field = "dataset_id"
|
||||
|
||||
@ -1,39 +1,35 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
app_mode: str = Field(..., description="Application mode")
|
||||
model_mode: str = Field(..., description="Model mode")
|
||||
has_context: str = Field(default="true", description="Whether has context")
|
||||
model_name: str = Field(..., description="Model name")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
AdvancedPromptTemplateQuery.__name__,
|
||||
AdvancedPromptTemplateQuery.model_json_schema(ref_template="#/definitions/{model}"),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/app/prompt-templates")
|
||||
class AdvancedPromptTemplateList(Resource):
|
||||
@console_ns.doc("get_advanced_prompt_templates")
|
||||
@console_ns.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@console_ns.expect(console_ns.models[AdvancedPromptTemplateQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_advanced_prompt_templates")
|
||||
@api.doc(description="Get advanced prompt templates based on app mode and model configuration")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("app_mode", type=str, required=True, location="args", help="Application mode")
|
||||
.add_argument("model_mode", type=str, required=True, location="args", help="Model mode")
|
||||
.add_argument("has_context", type=str, default="true", location="args", help="Whether has context")
|
||||
.add_argument("model_name", type=str, required=True, location="args", help="Model name")
|
||||
)
|
||||
@api.response(
|
||||
200, "Prompt templates retrieved successfully", fields.List(fields.Raw(description="Prompt template data"))
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("app_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("model_mode", type=str, required=True, location="args")
|
||||
parser.add_argument("has_context", type=str, required=False, default="true", location="args")
|
||||
parser.add_argument("model_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
return AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.helper import uuid_value
|
||||
@ -8,29 +8,29 @@ from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from services.agent_service import AgentService
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=uuid_value, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=uuid_value, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/agent/logs")
|
||||
class AgentLogApi(Resource):
|
||||
@console_ns.doc("get_agent_logs")
|
||||
@console_ns.doc(description="Get agent execution logs for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(parser)
|
||||
@console_ns.response(
|
||||
200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries"))
|
||||
@api.doc("get_agent_logs")
|
||||
@api.doc(description="Get agent execution logs for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("message_id", type=str, required=True, location="args", help="Message UUID")
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation UUID")
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.response(200, "Agent logs retrieved successfully", fields.List(fields.Raw(description="Agent log entries")))
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT])
|
||||
def get(self, app_model):
|
||||
"""Get agent logs"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=uuid_value, required=True, location="args")
|
||||
parser.add_argument("conversation_id", type=uuid_value, required=True, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"])
|
||||
|
||||
@ -1,34 +1,33 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
annotation_hit_history_fields,
|
||||
build_annotation_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@console_ns.doc("annotation_reply_action")
|
||||
@console_ns.doc(description="Enable or disable annotation reply for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("annotation_reply_action")
|
||||
@api.doc(description="Enable or disable annotation reply for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "action": "Action to perform (enable/disable)"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AnnotationReplyActionRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold for annotation matching"),
|
||||
@ -37,21 +36,21 @@ class AnnotationReplyActionApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Action completed successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(200, "Action completed successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id, action: Literal["enable", "disable"]):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_id)
|
||||
@ -62,16 +61,18 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-setting")
|
||||
class AppAnnotationSettingDetailApi(Resource):
|
||||
@console_ns.doc("get_annotation_setting")
|
||||
@console_ns.doc(description="Get annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Annotation settings retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("get_annotation_setting")
|
||||
@api.doc(description="Get annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotation settings retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id)
|
||||
return result, 200
|
||||
@ -79,11 +80,11 @@ class AppAnnotationSettingDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>")
|
||||
class AppAnnotationSettingUpdateApi(Resource):
|
||||
@console_ns.doc("update_annotation_setting")
|
||||
@console_ns.doc(description="Update annotation settings for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_annotation_setting")
|
||||
@api.doc(description="Update annotation settings for an app")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_setting_id": "Annotation setting ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AnnotationSettingUpdateRequest",
|
||||
{
|
||||
"score_threshold": fields.Float(required=True, description="Score threshold"),
|
||||
@ -92,17 +93,20 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Settings updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(200, "Settings updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id, annotation_setting_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_setting_id = str(annotation_setting_id)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args)
|
||||
@ -111,17 +115,19 @@ class AppAnnotationSettingUpdateApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>")
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@console_ns.doc("get_annotation_reply_action_status")
|
||||
@console_ns.doc(description="Get status of annotation reply action job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("get_annotation_reply_action_status")
|
||||
@api.doc(description="Get status of annotation reply action job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID", "action": "Action type"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id, action):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@ -139,22 +145,24 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class AnnotationApi(Resource):
|
||||
@console_ns.doc("list_annotations")
|
||||
@console_ns.doc(description="Get annotations for an app with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
@api.doc("list_annotations")
|
||||
@api.doc(description="Get annotations for an app with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
.add_argument("keyword", type=str, location="args", default="", help="Search keyword")
|
||||
)
|
||||
@console_ns.response(200, "Annotations retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(200, "Annotations retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@ -170,48 +178,45 @@ class AnnotationApi(Resource):
|
||||
}
|
||||
return response, 200
|
||||
|
||||
@console_ns.doc("create_annotation")
|
||||
@console_ns.doc(description="Create a new annotation for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("create_annotation")
|
||||
@api.doc(description="Create a new annotation for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID (optional)"),
|
||||
"question": fields.String(description="Question text (required when message_id not provided)"),
|
||||
"answer": fields.String(description="Answer text (use 'answer' or 'content')"),
|
||||
"content": fields.String(description="Content text (use 'answer' or 'content')"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply data"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(201, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@marshal_with(annotation_fields)
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
.add_argument("question", required=False, type=str, location="json")
|
||||
.add_argument("answer", required=False, type=str, location="json")
|
||||
.add_argument("content", required=False, type=str, location="json")
|
||||
.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id)
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id)
|
||||
return annotation
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
|
||||
# Use request.args.getlist to get annotation_ids array directly
|
||||
@ -236,51 +241,46 @@ class AnnotationApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
class AnnotationExportApi(Resource):
|
||||
@console_ns.doc("export_annotations")
|
||||
@console_ns.doc(description="Export all annotations for an app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Annotations exported successfully",
|
||||
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("export_annotations")
|
||||
@api.doc(description="Export all annotations for an app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Annotations exported successfully", fields.List(fields.Nested(annotation_fields)))
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
|
||||
response = {"data": marshal(annotation_list, annotation_fields)}
|
||||
return response, 200
|
||||
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json")
|
||||
.add_argument("answer", required=True, type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@console_ns.doc("update_delete_annotation")
|
||||
@console_ns.doc(description="Update or delete an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
|
||||
@console_ns.response(204, "Annotation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.expect(parser)
|
||||
@api.doc("update_delete_annotation")
|
||||
@api.doc(description="Update or delete an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.response(200, "Annotation updated successfully", annotation_fields)
|
||||
@api.response(204, "Annotation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id)
|
||||
return annotation
|
||||
@ -288,8 +288,10 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
annotation_id = str(annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
|
||||
@ -298,18 +300,20 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
class AnnotationBatchImportApi(Resource):
|
||||
@console_ns.doc("batch_import_annotations")
|
||||
@console_ns.doc(description="Batch import annotations from CSV file")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Batch import started successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "No file uploaded or too many files")
|
||||
@api.doc("batch_import_annotations")
|
||||
@api.doc(description="Batch import annotations from CSV file")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Batch import started successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "No file uploaded or too many files")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_id = str(app_id)
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
@ -328,17 +332,19 @@ class AnnotationBatchImportApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>")
|
||||
class AnnotationBatchImportStatusApi(Resource):
|
||||
@console_ns.doc("get_batch_import_status")
|
||||
@console_ns.doc(description="Get status of batch import job")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@console_ns.response(200, "Job status retrieved successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("get_batch_import_status")
|
||||
@api.doc(description="Get status of batch import job")
|
||||
@api.doc(params={"app_id": "Application ID", "job_id": "Job ID"})
|
||||
@api.response(200, "Job status retrieved successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@edit_permission_required
|
||||
def get(self, app_id, job_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
job_id = str(job_id)
|
||||
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
@ -355,32 +361,25 @@ class AnnotationBatchImportStatusApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories")
|
||||
class AnnotationHitHistoryListApi(Resource):
|
||||
@console_ns.doc("list_annotation_hit_histories")
|
||||
@console_ns.doc(description="Get hit histories for an annotation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser()
|
||||
@api.doc("list_annotation_hit_histories")
|
||||
@api.doc(description="Get hit histories for an annotation")
|
||||
@api.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size")
|
||||
)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Hit histories retrieved successfully",
|
||||
console_ns.model(
|
||||
"AnnotationHitHistoryList",
|
||||
{
|
||||
"data": fields.List(
|
||||
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
|
||||
)
|
||||
},
|
||||
),
|
||||
@api.response(
|
||||
200, "Hit histories retrieved successfully", fields.List(fields.Nested(annotation_hit_history_fields))
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_id, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
app_id = str(app_id)
|
||||
|
||||
@ -1,295 +1,96 @@
|
||||
import uuid
|
||||
from typing import Literal
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, abort
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
enterprise_license_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
deleted_tool_fields,
|
||||
model_config_fields,
|
||||
model_config_partial_fields,
|
||||
site_fields,
|
||||
tag_fields,
|
||||
)
|
||||
from fields.workflow_fields import workflow_partial_fields as _workflow_partial_fields_dict
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
|
||||
from libs.login import login_required
|
||||
from libs.validators import validate_description_length
|
||||
from models import App, Workflow
|
||||
from models import Account, App
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
|
||||
if not items:
|
||||
return None
|
||||
|
||||
try:
|
||||
return [str(uuid.UUID(item)) for item in items]
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class UpdateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
|
||||
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class CopyAppPayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="Name for the copied app")
|
||||
description: str | None = Field(default=None, description="Description for the copied app")
|
||||
icon_type: str | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@field_validator("description")
|
||||
@classmethod
|
||||
def validate_description(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return validate_description_length(value)
|
||||
|
||||
|
||||
class AppExportQuery(BaseModel):
|
||||
include_secret: bool = Field(default=False, description="Include secrets in export")
|
||||
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
|
||||
|
||||
|
||||
class AppNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="Name to check")
|
||||
|
||||
|
||||
class AppIconPayload(BaseModel):
|
||||
icon: str | None = Field(default=None, description="Icon data")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
|
||||
class AppSiteStatusPayload(BaseModel):
|
||||
enable_site: bool = Field(..., description="Enable or disable site")
|
||||
|
||||
|
||||
class AppApiStatusPayload(BaseModel):
|
||||
enable_api: bool = Field(..., description="Enable or disable API")
|
||||
|
||||
|
||||
class AppTracePayload(BaseModel):
|
||||
enabled: bool = Field(..., description="Enable or disable tracing")
|
||||
tracing_provider: str = Field(..., description="Tracing provider")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(AppListQuery)
|
||||
reg(CreateAppPayload)
|
||||
reg(UpdateAppPayload)
|
||||
reg(CopyAppPayload)
|
||||
reg(AppExportQuery)
|
||||
reg(AppNamePayload)
|
||||
reg(AppIconPayload)
|
||||
reg(AppSiteStatusPayload)
|
||||
reg(AppApiStatusPayload)
|
||||
reg(AppTracePayload)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base models first
|
||||
tag_model = console_ns.model("Tag", tag_fields)
|
||||
|
||||
workflow_partial_model = console_ns.model("WorkflowPartial", _workflow_partial_fields_dict)
|
||||
|
||||
model_config_model = console_ns.model("ModelConfig", model_config_fields)
|
||||
|
||||
model_config_partial_model = console_ns.model("ModelConfigPartial", model_config_partial_fields)
|
||||
|
||||
deleted_tool_model = console_ns.model("DeletedTool", deleted_tool_fields)
|
||||
|
||||
site_model = console_ns.model("Site", site_fields)
|
||||
|
||||
app_partial_model = console_ns.model(
|
||||
"AppPartial",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"max_active_requests": fields.Raw(),
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"model_config": fields.Nested(model_config_partial_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"access_mode": fields.String,
|
||||
"create_user_name": fields.String,
|
||||
"author_name": fields.String,
|
||||
"has_draft_trigger": fields.Boolean,
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_model = console_ns.model(
|
||||
"AppDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
},
|
||||
)
|
||||
|
||||
app_detail_with_site_model = console_ns.model(
|
||||
"AppDetailWithSite",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"model_config": fields.Nested(model_config_model, attribute="app_model_config", allow_null=True),
|
||||
"workflow": fields.Nested(workflow_partial_model, allow_null=True),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"max_active_requests": fields.Integer,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"deleted_tools": fields.List(fields.Nested(deleted_tool_model)),
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_model)),
|
||||
"site": fields.Nested(site_model),
|
||||
},
|
||||
)
|
||||
|
||||
app_pagination_model = console_ns.model(
|
||||
"AppPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps")
|
||||
class AppListApi(Resource):
|
||||
@console_ns.doc("list_apps")
|
||||
@console_ns.doc(description="Get list of applications with pagination and filtering")
|
||||
@console_ns.expect(console_ns.models[AppListQuery.__name__])
|
||||
@console_ns.response(200, "Success", app_pagination_model)
|
||||
@api.doc("list_apps")
|
||||
@api.doc(description="Get list of applications with pagination and filtering")
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("page", type=int, location="args", help="Page number (1-99999)", default=1)
|
||||
.add_argument("limit", type=int, location="args", help="Page size (1-100)", default=20)
|
||||
.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"],
|
||||
default="all",
|
||||
help="App mode filter",
|
||||
)
|
||||
.add_argument("name", type=str, location="args", help="Filter by app name")
|
||||
.add_argument("tag_ids", type=str, location="args", help="Comma-separated tag IDs")
|
||||
.add_argument("is_created_by_me", type=bool, location="args", help="Filter by creator")
|
||||
)
|
||||
@api.response(200, "Success", app_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args_dict = args.model_dump()
|
||||
def uuid_list(value):
|
||||
try:
|
||||
return [str(uuid.UUID(v)) for v in value.split(",")]
|
||||
except ValueError:
|
||||
abort(400, message="Invalid UUID format in tag_ids.")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
type=str,
|
||||
choices=[
|
||||
"completion",
|
||||
"chat",
|
||||
"advanced-chat",
|
||||
"workflow",
|
||||
"agent-chat",
|
||||
"channel",
|
||||
"all",
|
||||
],
|
||||
default="all",
|
||||
location="args",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument("name", type=str, location="args", required=False)
|
||||
parser.add_argument("tag_ids", type=uuid_list, location="args", required=False)
|
||||
parser.add_argument("is_created_by_me", type=inputs.boolean, location="args", required=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# get app list
|
||||
app_service = AppService()
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_tenant_id, args_dict)
|
||||
app_pagination = app_service.get_paginate_apps(current_user.id, current_user.current_tenant_id, args)
|
||||
if not app_pagination:
|
||||
return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False}
|
||||
|
||||
@ -303,72 +104,71 @@ class AppListApi(Resource):
|
||||
if str(app.id) in res:
|
||||
app.access_mode = res[str(app.id)].access_mode
|
||||
|
||||
workflow_capable_app_ids = [
|
||||
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
|
||||
]
|
||||
draft_trigger_app_ids: set[str] = set()
|
||||
if workflow_capable_app_ids:
|
||||
draft_workflows = (
|
||||
db.session.execute(
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
trigger_node_types = {
|
||||
NodeType.TRIGGER_WEBHOOK,
|
||||
NodeType.TRIGGER_SCHEDULE,
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
return marshal(app_pagination, app_pagination_fields), 200
|
||||
|
||||
for app in app_pagination.items:
|
||||
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
|
||||
|
||||
return marshal(app_pagination, app_pagination_model), 200
|
||||
|
||||
@console_ns.doc("create_app")
|
||||
@console_ns.doc(description="Create a new application")
|
||||
@console_ns.expect(console_ns.models[CreateAppPayload.__name__])
|
||||
@console_ns.response(201, "App created successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.doc("create_app")
|
||||
@api.doc(description="Create a new application")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"mode": fields.String(required=True, enum=ALLOW_CREATE_APP_MODES, description="App mode"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App created successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_detail_model)
|
||||
@marshal_with(app_detail_fields)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
"""Create app"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = CreateAppPayload.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if "mode" not in args or args["mode"] is None:
|
||||
raise BadRequest("mode is required")
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(current_tenant_id, args.model_dump(), current_user)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
if current_user.current_tenant_id is None:
|
||||
raise ValueError("current_user.current_tenant_id cannot be None")
|
||||
app = app_service.create_app(current_user.current_tenant_id, args, current_user)
|
||||
|
||||
return app, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>")
|
||||
class AppApi(Resource):
|
||||
@console_ns.doc("get_app_detail")
|
||||
@console_ns.doc(description="Get application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Success", app_detail_with_site_model)
|
||||
@api.doc("get_app_detail")
|
||||
@api.doc(description="Get application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Success", app_detail_fields_with_site)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
"""Get app detail"""
|
||||
app_service = AppService()
|
||||
@ -381,50 +181,79 @@ class AppApi(Resource):
|
||||
|
||||
return app_model
|
||||
|
||||
@console_ns.doc("update_app")
|
||||
@console_ns.doc(description="Update application details")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[UpdateAppPayload.__name__])
|
||||
@console_ns.response(200, "App updated successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.doc("update_app")
|
||||
@api.doc(description="Update application details")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateAppRequest",
|
||||
{
|
||||
"name": fields.String(required=True, description="App name"),
|
||||
"description": fields.String(description="App description (max 400 chars)"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
"use_icon_as_answer_icon": fields.Boolean(description="Use icon as answer icon"),
|
||||
"max_active_requests": fields.Integer(description="Maximum active requests"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "App updated successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def put(self, app_model):
|
||||
"""Update app"""
|
||||
args = UpdateAppPayload.model_validate(console_ns.payload)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
# Construct ArgsDict from parsed arguments
|
||||
from services.app_service import AppService as AppServiceType
|
||||
|
||||
args_dict: AppService.ArgsDict = {
|
||||
"name": args.name,
|
||||
"description": args.description or "",
|
||||
"icon_type": args.icon_type or "",
|
||||
"icon": args.icon or "",
|
||||
"icon_background": args.icon_background or "",
|
||||
"use_icon_as_answer_icon": args.use_icon_as_answer_icon or False,
|
||||
"max_active_requests": args.max_active_requests or 0,
|
||||
args_dict: AppServiceType.ArgsDict = {
|
||||
"name": args["name"],
|
||||
"description": args.get("description", ""),
|
||||
"icon_type": args.get("icon_type", ""),
|
||||
"icon": args.get("icon", ""),
|
||||
"icon_background": args.get("icon_background", ""),
|
||||
"use_icon_as_answer_icon": args.get("use_icon_as_answer_icon", False),
|
||||
"max_active_requests": args.get("max_active_requests", 0),
|
||||
}
|
||||
app_model = app_service.update_app(app_model, args_dict)
|
||||
|
||||
return app_model
|
||||
|
||||
@console_ns.doc("delete_app")
|
||||
@console_ns.doc(description="Delete application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(204, "App deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("delete_app")
|
||||
@api.doc(description="Delete application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(204, "App deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model):
|
||||
"""Delete app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
@ -433,37 +262,55 @@ class AppApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
class AppCopyApi(Resource):
|
||||
@console_ns.doc("copy_app")
|
||||
@console_ns.doc(description="Create a copy of an existing application")
|
||||
@console_ns.doc(params={"app_id": "Application ID to copy"})
|
||||
@console_ns.expect(console_ns.models[CopyAppPayload.__name__])
|
||||
@console_ns.response(201, "App copied successfully", app_detail_with_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("copy_app")
|
||||
@api.doc(description="Create a copy of an existing application")
|
||||
@api.doc(params={"app_id": "Application ID to copy"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CopyAppRequest",
|
||||
{
|
||||
"name": fields.String(description="Name for the copied app"),
|
||||
"description": fields.String(description="Description for the copied app"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon": fields.String(description="Icon"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(201, "App copied successfully", app_detail_fields_with_site)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_with_site_model)
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def post(self, app_model):
|
||||
"""Copy app"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=validate_description_length, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=current_user,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT.value,
|
||||
yaml_content=yaml_content,
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
icon_type=args.icon_type,
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
name=args.get("name"),
|
||||
description=args.get("description"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@ -475,131 +322,178 @@ class AppCopyApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/export")
|
||||
class AppExportApi(Resource):
|
||||
@console_ns.doc("export_app")
|
||||
@console_ns.doc(description="Export application configuration as DSL")
|
||||
@console_ns.doc(params={"app_id": "Application ID to export"})
|
||||
@console_ns.expect(console_ns.models[AppExportQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("export_app")
|
||||
@api.doc(description="Export application configuration as DSL")
|
||||
@api.doc(params={"app_id": "Application ID to export"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("include_secret", type=bool, location="args", default=False, help="Include secrets in export")
|
||||
.add_argument("workflow_id", type=str, location="args", help="Specific workflow ID to export")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"App exported successfully",
|
||||
console_ns.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
api.model("AppExportResponse", {"data": fields.String(description="DSL export data")}),
|
||||
)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
"""Export app"""
|
||||
args = AppExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model,
|
||||
include_secret=args.include_secret,
|
||||
workflow_id=args.workflow_id,
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
@console_ns.doc(description="Check if app name is available")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppNamePayload.__name__])
|
||||
@console_ns.response(200, "Name availability checked")
|
||||
@api.doc("check_app_name")
|
||||
@api.doc(description="Check if app name is available")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(api.parser().add_argument("name", type=str, required=True, location="args", help="Name to check"))
|
||||
@api.response(200, "Name availability checked")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
args = AppNamePayload.model_validate(console_ns.payload)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_name(app_model, args.name)
|
||||
app_model = app_service.update_app_name(app_model, args["name"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/icon")
|
||||
class AppIconApi(Resource):
|
||||
@console_ns.doc("update_app_icon")
|
||||
@console_ns.doc(description="Update application icon")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppIconPayload.__name__])
|
||||
@console_ns.response(200, "Icon updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("update_app_icon")
|
||||
@api.doc(description="Update application icon")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppIconRequest",
|
||||
{
|
||||
"icon": fields.String(required=True, description="Icon data"),
|
||||
"icon_type": fields.String(description="Icon type"),
|
||||
"icon_background": fields.String(description="Icon background color"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Icon updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
args = AppIconPayload.model_validate(console_ns.payload or {})
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
|
||||
app_model = app_service.update_app_icon(app_model, args.get("icon") or "", args.get("icon_background") or "")
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site-enable")
|
||||
class AppSiteStatus(Resource):
|
||||
@console_ns.doc("update_app_site_status")
|
||||
@console_ns.doc(description="Enable or disable app site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppSiteStatusPayload.__name__])
|
||||
@console_ns.response(200, "Site status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("update_app_site_status")
|
||||
@api.doc(description="Enable or disable app site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppSiteStatusRequest", {"enable_site": fields.Boolean(required=True, description="Enable or disable site")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "Site status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
args = AppSiteStatusPayload.model_validate(console_ns.payload)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_site", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_site_status(app_model, args.enable_site)
|
||||
app_model = app_service.update_app_site_status(app_model, args["enable_site"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/api-enable")
|
||||
class AppApiStatus(Resource):
|
||||
@console_ns.doc("update_app_api_status")
|
||||
@console_ns.doc(description="Enable or disable app API")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppApiStatusPayload.__name__])
|
||||
@console_ns.response(200, "API status updated successfully", app_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("update_app_api_status")
|
||||
@api.doc(description="Enable or disable app API")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppApiStatusRequest", {"enable_api": fields.Boolean(required=True, description="Enable or disable API")}
|
||||
)
|
||||
)
|
||||
@api.response(200, "API status updated successfully", app_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_model)
|
||||
@marshal_with(app_detail_fields)
|
||||
def post(self, app_model):
|
||||
args = AppApiStatusPayload.model_validate(console_ns.payload)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enable_api", type=bool, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
app_model = app_service.update_app_api_status(app_model, args.enable_api)
|
||||
app_model = app_service.update_app_api_status(app_model, args["enable_api"])
|
||||
|
||||
return app_model
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trace")
|
||||
class AppTraceApi(Resource):
|
||||
@console_ns.doc("get_app_trace")
|
||||
@console_ns.doc(description="Get app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Trace configuration retrieved successfully")
|
||||
@api.doc("get_app_trace")
|
||||
@api.doc(description="Get app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Trace configuration retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -609,24 +503,36 @@ class AppTraceApi(Resource):
|
||||
|
||||
return app_trace_config
|
||||
|
||||
@console_ns.doc("update_app_trace")
|
||||
@console_ns.doc(description="Update app tracing configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[AppTracePayload.__name__])
|
||||
@console_ns.response(200, "Trace configuration updated successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("update_app_trace")
|
||||
@api.doc(description="Update app tracing configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppTraceRequest",
|
||||
{
|
||||
"enabled": fields.Boolean(required=True, description="Enable or disable tracing"),
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Trace configuration updated successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, app_id):
|
||||
# add app trace
|
||||
args = AppTracePayload.model_validate(console_ns.payload)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("enabled", type=bool, required=True, location="json")
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
OpsTraceManager.update_app_tracing_config(
|
||||
app_id=app_id,
|
||||
enabled=args.enabled,
|
||||
tracing_provider=args.tracing_provider,
|
||||
enabled=args["enabled"],
|
||||
tracing_provider=args["tracing_provider"],
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,20 +1,20 @@
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import (
|
||||
app_import_check_dependencies_fields,
|
||||
app_import_fields,
|
||||
leaked_dependency_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from fields.app_fields import app_import_check_dependencies_fields, app_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.model import App
|
||||
from services.app_dsl_service import AppDslService, ImportStatus
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -22,52 +22,36 @@ from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
|
||||
|
||||
app_import_model = console_ns.model("AppImport", app_import_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
|
||||
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
|
||||
app_import_check_dependencies_model = console_ns.model(
|
||||
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
|
||||
)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("app_id", type=str, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/imports")
|
||||
class AppImportApi(Resource):
|
||||
@console_ns.expect(parser)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@marshal_with(app_import_fields)
|
||||
@cloud_edition_billing_resource_check("apps")
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("app_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_app(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
@ -86,9 +70,9 @@ class AppImportApi(Resource):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -98,22 +82,22 @@ class AppImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -124,9 +108,11 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@login_required
|
||||
@get_app_model
|
||||
@account_initialization_required
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_import_check_dependencies_fields)
|
||||
def get(self, app_model: App):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
@ -5,7 +5,7 @@ from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@ -36,16 +36,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@console_ns.doc("chat_message_audio_transcript")
|
||||
@console_ns.doc(description="Transcript audio to text for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.response(
|
||||
@api.doc("chat_message_audio_transcript")
|
||||
@api.doc(description="Transcript audio to text for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Audio transcription successful",
|
||||
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
api.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
|
||||
)
|
||||
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@console_ns.response(413, "Audio file too large")
|
||||
@api.response(400, "Bad request - No audio uploaded or unsupported type")
|
||||
@api.response(413, "Audio file too large")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -89,11 +89,11 @@ class ChatMessageAudioApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio")
|
||||
class ChatMessageTextApi(Resource):
|
||||
@console_ns.doc("chat_message_text_to_speech")
|
||||
@console_ns.doc(description="Convert text to speech for chat messages")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("chat_message_text_to_speech")
|
||||
@api.doc(description="Convert text to speech for chat messages")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TextToSpeechRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
@ -103,21 +103,19 @@ class ChatMessageTextApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Text to speech conversion successful")
|
||||
@console_ns.response(400, "Bad request - Invalid parameters")
|
||||
@api.response(200, "Text to speech conversion successful")
|
||||
@api.response(400, "Bad request - Invalid parameters")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model: App):
|
||||
try:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, location="json")
|
||||
.add_argument("text", type=str, location="json")
|
||||
.add_argument("voice", type=str, location="json")
|
||||
.add_argument("streaming", type=bool, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
@ -156,23 +154,20 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/text-to-audio/voices")
|
||||
class TextModesApi(Resource):
|
||||
@console_ns.doc("get_text_to_speech_voices")
|
||||
@console_ns.doc(description="Get available TTS voices for a specific language")
|
||||
@console_ns.doc(params={"app_id": "App ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument("language", type=str, required=True, location="args", help="Language code")
|
||||
)
|
||||
@console_ns.response(
|
||||
200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices"))
|
||||
)
|
||||
@console_ns.response(400, "Invalid language parameter")
|
||||
@api.doc("get_text_to_speech_voices")
|
||||
@api.doc(description="Get available TTS voices for a specific language")
|
||||
@api.doc(params={"app_id": "App ID"})
|
||||
@api.expect(api.parser().add_argument("language", type=str, required=True, location="args", help="Language code"))
|
||||
@api.response(200, "TTS voices retrieved successfully", fields.List(fields.Raw(description="Available voices")))
|
||||
@api.response(400, "Invalid language parameter")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
try:
|
||||
parser = reqparse.RequestParser().add_argument("language", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("language", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
response = AudioService.transcript_tts_voices(
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
@ -17,8 +15,9 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
@ -33,66 +32,48 @@ from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseMessagePayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config")
|
||||
files: list[Any] | None = Field(default=None, description="Uploaded files")
|
||||
response_mode: Literal["blocking", "streaming"] = Field(default="blocking", description="Response mode")
|
||||
retriever_from: str = Field(default="dev", description="Retriever source")
|
||||
|
||||
|
||||
class CompletionMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(default="", description="Query text")
|
||||
|
||||
|
||||
class ChatMessagePayload(BaseMessagePayload):
|
||||
query: str = Field(..., description="User query")
|
||||
conversation_id: str | None = Field(default=None, description="Conversation ID")
|
||||
parent_message_id: str | None = Field(default=None, description="Parent message ID")
|
||||
|
||||
@field_validator("conversation_id", "parent_message_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionMessagePayload.__name__,
|
||||
CompletionMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatMessagePayload.__name__, ChatMessagePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages")
|
||||
class CompletionMessageApi(Resource):
|
||||
@console_ns.doc("create_completion_message")
|
||||
@console_ns.doc(description="Generate completion message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
|
||||
@console_ns.response(200, "Completion generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App not found")
|
||||
@api.doc("create_completion_message")
|
||||
@api.doc(description="Generate completion message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CompletionMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(description="Query text", default=""),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Completion generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model):
|
||||
args_model = CompletionMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
try:
|
||||
@ -127,10 +108,10 @@ class CompletionMessageApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop")
|
||||
class CompletionMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_completion_message")
|
||||
@console_ns.doc(description="Stop a running completion message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@api.doc("stop_completion_message")
|
||||
@api.doc(description="Stop a running completion message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -138,36 +119,57 @@ class CompletionMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageApi(Resource):
|
||||
@console_ns.doc("create_chat_message")
|
||||
@console_ns.doc(description="Generate chat message for debugging")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagePayload.__name__])
|
||||
@console_ns.response(200, "Chat message generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(404, "App or conversation not found")
|
||||
@api.doc("create_chat_message")
|
||||
@api.doc(description="Generate chat message for debugging")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ChatMessageRequest",
|
||||
{
|
||||
"inputs": fields.Raw(required=True, description="Input variables"),
|
||||
"query": fields.String(required=True, description="User query"),
|
||||
"files": fields.List(fields.Raw(), description="Uploaded files"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"conversation_id": fields.String(description="Conversation ID"),
|
||||
"parent_message_id": fields.String(description="Parent message ID"),
|
||||
"response_mode": fields.String(enum=["blocking", "streaming"], description="Response mode"),
|
||||
"retriever_from": fields.String(default="dev", description="Retriever source"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Chat message generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(404, "App or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
args = args_model.model_dump(exclude_none=True, by_alias=True)
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
streaming = args_model.response_mode != "blocking"
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] != "blocking"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -208,10 +210,10 @@ class ChatMessageApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop")
|
||||
class ChatMessageStopApi(Resource):
|
||||
@console_ns.doc("stop_chat_message")
|
||||
@console_ns.doc(description="Stop a running chat message generation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@console_ns.response(200, "Task stopped successfully")
|
||||
@api.doc("stop_chat_message")
|
||||
@api.doc(description="Stop a running chat message generation")
|
||||
@api.doc(params={"app_id": "Application ID", "task_id": "Task ID to stop"})
|
||||
@api.response(200, "Task stopped successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -219,12 +221,6 @@ class ChatMessageStopApi(Resource):
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1,377 +1,116 @@
|
||||
from typing import Literal
|
||||
from datetime import datetime
|
||||
|
||||
import pytz # pip install pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import joinedload
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_fields import MessageTextField
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from fields.conversation_fields import (
|
||||
conversation_detail_fields,
|
||||
conversation_message_detail_fields,
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import Account, Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class BaseConversationQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword")
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
annotation_status: Literal["annotated", "not_annotated", "all"] = Field(
|
||||
default="all", description="Annotation status filter"
|
||||
)
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
class CompletionConversationQuery(BaseConversationQuery):
|
||||
pass
|
||||
|
||||
|
||||
class ChatConversationQuery(BaseConversationQuery):
|
||||
message_count_gte: int | None = Field(default=None, ge=1, description="Minimum message count")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort field and direction"
|
||||
)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
CompletionConversationQuery.__name__,
|
||||
CompletionConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatConversationQuery.__name__,
|
||||
ChatConversationQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
feedback_stat_model = console_ns.model(
|
||||
"FeedbackStat",
|
||||
{
|
||||
"like": fields.Integer,
|
||||
"dislike": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
status_count_model = console_ns.model(
|
||||
"StatusCount",
|
||||
{
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
simple_model_config_model = console_ns.model(
|
||||
"SimpleModelConfig",
|
||||
{
|
||||
"model": fields.Raw(attribute="model_dict"),
|
||||
"pre_prompt": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
model_config_model = console_ns.model(
|
||||
"ModelConfig",
|
||||
{
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"model": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"pre_prompt": fields.String,
|
||||
"agent_mode": fields.Raw,
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Simple message detail model
|
||||
simple_message_detail_model = console_ns.model(
|
||||
"SimpleMessageDetail",
|
||||
{
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": MessageTextField,
|
||||
"answer": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Conversation models
|
||||
conversation_fields_model = console_ns.model(
|
||||
"Conversation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String(),
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"message": fields.Nested(simple_message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_pagination_model = console_ns.model(
|
||||
"ConversationPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_fields_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_message_detail_model = console_ns.model(
|
||||
"ConversationMessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message": fields.Nested(message_detail_model, attribute="first_message"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_model = console_ns.model(
|
||||
"ConversationWithSummary",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_end_user_session_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"from_account_name": fields.String,
|
||||
"name": fields.String,
|
||||
"summary": fields.String(attribute="summary_or_query"),
|
||||
"read_at": TimestampField,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"model_config": fields.Nested(simple_model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"status_count": fields.Nested(status_count_model),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_with_summary_pagination_model = console_ns.model(
|
||||
"ConversationWithSummaryPagination",
|
||||
{
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(conversation_with_summary_model), attribute="items"),
|
||||
},
|
||||
)
|
||||
|
||||
conversation_detail_model = console_ns.model(
|
||||
"ConversationDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"annotated": fields.Boolean,
|
||||
"introduction": fields.String,
|
||||
"model_config": fields.Nested(model_config_model),
|
||||
"message_count": fields.Integer,
|
||||
"user_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
"admin_feedback_stats": fields.Nested(feedback_stat_model),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations")
|
||||
class CompletionConversationApi(Resource):
|
||||
@console_ns.doc("list_completion_conversations")
|
||||
@console_ns.doc(description="Get completion conversations with pagination and filtering")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[CompletionConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("list_completion_conversations")
|
||||
@api.doc(description="Get completion conversations with pagination and filtering")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", conversation_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_pagination_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(conversation_pagination_fields)
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
query = sa.select(Conversation).where(
|
||||
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
|
||||
)
|
||||
|
||||
if args.keyword:
|
||||
if args["keyword"]:
|
||||
query = query.join(Message, Message.conversation_id == Conversation.id).where(
|
||||
or_(
|
||||
Message.query.ilike(f"%{args.keyword}%"),
|
||||
Message.answer.ilike(f"%{args.keyword}%"),
|
||||
Message.query.ilike(f"%{args['keyword']}%"),
|
||||
Message.answer.ilike(f"%{args['keyword']}%"),
|
||||
)
|
||||
)
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
|
||||
# FIXME, the type ignore in this file
|
||||
if args.annotation_status == "annotated":
|
||||
if args["annotation_status"] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
@ -380,46 +119,49 @@ class CompletionConversationApi(Resource):
|
||||
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>")
|
||||
class CompletionConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_completion_conversation")
|
||||
@console_ns.doc(description="Get completion conversation details with messages")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_message_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@api.doc("get_completion_conversation")
|
||||
@api.doc(description="Get completion conversation details with messages")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_message_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@marshal_with(conversation_message_detail_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(conversation_message_detail_fields)
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@console_ns.doc("delete_completion_conversation")
|
||||
@console_ns.doc(description="Delete a completion conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@api.doc("delete_completion_conversation")
|
||||
@api.doc(description="Delete a completion conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -429,21 +171,63 @@ class CompletionConversationDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
class ChatConversationApi(Resource):
|
||||
@console_ns.doc("list_chat_conversations")
|
||||
@console_ns.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatConversationQuery.__name__])
|
||||
@console_ns.response(200, "Success", conversation_with_summary_pagination_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("list_chat_conversations")
|
||||
@api.doc(description="Get chat conversations with pagination, filtering and summary")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("keyword", type=str, location="args", help="Search keyword")
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument(
|
||||
"annotation_status",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["annotated", "not_annotated", "all"],
|
||||
default="all",
|
||||
help="Annotation status filter",
|
||||
)
|
||||
.add_argument("message_count_gte", type=int, location="args", help="Minimum message count")
|
||||
.add_argument("page", type=int, location="args", default=1, help="Page number")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Page size (1-100)")
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
location="args",
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
default="-updated_at",
|
||||
help="Sort field and direction",
|
||||
)
|
||||
)
|
||||
@api.response(200, "Success", conversation_with_summary_pagination_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_with_summary_pagination_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(conversation_with_summary_pagination_fields)
|
||||
def get(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = ChatConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
subquery = (
|
||||
db.session.query(
|
||||
@ -455,8 +239,8 @@ class ChatConversationApi(Resource):
|
||||
|
||||
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
|
||||
|
||||
if args.keyword:
|
||||
keyword_filter = f"%{args.keyword}%"
|
||||
if args["keyword"]:
|
||||
keyword_filter = f"%{args['keyword']}%"
|
||||
query = (
|
||||
query.join(
|
||||
Message,
|
||||
@ -476,51 +260,58 @@ class ChatConversationApi(Resource):
|
||||
)
|
||||
|
||||
account = current_user
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
if start_datetime_utc:
|
||||
match args.sort_by:
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if end_datetime_utc:
|
||||
end_datetime_utc = end_datetime_utc.replace(second=59)
|
||||
match args.sort_by:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=59)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args.annotation_status == "annotated":
|
||||
if args["annotation_status"] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
|
||||
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
|
||||
)
|
||||
elif args.annotation_status == "not_annotated":
|
||||
elif args["annotation_status"] == "not_annotated":
|
||||
query = (
|
||||
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(MessageAnnotation.id) == 0)
|
||||
)
|
||||
|
||||
if args.message_count_gte and args.message_count_gte >= 1:
|
||||
if args["message_count_gte"] and args["message_count_gte"] >= 1:
|
||||
query = (
|
||||
query.options(joinedload(Conversation.messages)) # type: ignore
|
||||
.join(Message, Message.conversation_id == Conversation.id)
|
||||
.group_by(Conversation.id)
|
||||
.having(func.count(Message.id) >= args.message_count_gte)
|
||||
.having(func.count(Message.id) >= args["message_count_gte"])
|
||||
)
|
||||
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
|
||||
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
|
||||
match args.sort_by:
|
||||
match args["sort_by"]:
|
||||
case "created_at":
|
||||
query = query.order_by(Conversation.created_at.asc())
|
||||
case "-created_at":
|
||||
@ -532,46 +323,49 @@ class ChatConversationApi(Resource):
|
||||
case _:
|
||||
query = query.order_by(Conversation.created_at.desc())
|
||||
|
||||
conversations = db.paginate(query, page=args.page, per_page=args.limit, error_out=False)
|
||||
conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>")
|
||||
class ChatConversationDetailApi(Resource):
|
||||
@console_ns.doc("get_chat_conversation")
|
||||
@console_ns.doc(description="Get chat conversation details")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(200, "Success", conversation_detail_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@api.doc("get_chat_conversation")
|
||||
@api.doc(description="Get chat conversation details")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(200, "Success", conversation_detail_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(conversation_detail_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(conversation_detail_fields)
|
||||
def get(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
return _get_conversation(app_model, conversation_id)
|
||||
|
||||
@console_ns.doc("delete_chat_conversation")
|
||||
@console_ns.doc(description="Delete a chat conversation")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@console_ns.response(204, "Conversation deleted successfully")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@api.doc("delete_chat_conversation")
|
||||
@api.doc(description="Delete a chat conversation")
|
||||
@api.doc(params={"app_id": "Application ID", "conversation_id": "Conversation ID"})
|
||||
@api.response(204, "Conversation deleted successfully")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def delete(self, app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
conversation_id = str(conversation_id)
|
||||
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -580,7 +374,6 @@ class ChatConversationDetailApi(Resource):
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
|
||||
|
||||
@ -1,68 +1,47 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.conversation_variable_fields import (
|
||||
conversation_variable_fields,
|
||||
paginated_conversation_variable_fields,
|
||||
)
|
||||
from fields.conversation_variable_fields import paginated_conversation_variable_fields
|
||||
from libs.login import login_required
|
||||
from models import ConversationVariable
|
||||
from models.model import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID to filter variables")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
ConversationVariablesQuery.__name__,
|
||||
ConversationVariablesQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register base model first
|
||||
conversation_variable_model = console_ns.model("ConversationVariable", conversation_variable_fields)
|
||||
|
||||
# For nested models, need to replace nested dict with registered model
|
||||
paginated_conversation_variable_fields_copy = paginated_conversation_variable_fields.copy()
|
||||
paginated_conversation_variable_fields_copy["data"] = fields.List(
|
||||
fields.Nested(conversation_variable_model), attribute="data"
|
||||
)
|
||||
paginated_conversation_variable_model = console_ns.model(
|
||||
"PaginatedConversationVariable", paginated_conversation_variable_fields_copy
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/conversation-variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ConversationVariablesQuery.__name__])
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_model)
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"conversation_id", type=str, location="args", help="Conversation ID to filter variables"
|
||||
)
|
||||
)
|
||||
@api.response(200, "Conversation variables retrieved successfully", paginated_conversation_variable_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
@marshal_with(paginated_conversation_variable_model)
|
||||
@marshal_with(paginated_conversation_variable_fields)
|
||||
def get(self, app_model):
|
||||
args = ConversationVariablesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
stmt = (
|
||||
select(ConversationVariable)
|
||||
.where(ConversationVariable.app_id == app_model.id)
|
||||
.order_by(ConversationVariable.created_at)
|
||||
)
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args.conversation_id)
|
||||
if args["conversation_id"]:
|
||||
stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"])
|
||||
else:
|
||||
raise ValueError("conversation_id is required")
|
||||
|
||||
# NOTE: This is a temporary solution to avoid performance issues.
|
||||
page = 1
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -13,80 +12,50 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
code_language: str = Field(default="javascript", description="Programming language for code generation")
|
||||
|
||||
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
|
||||
|
||||
class InstructionGeneratePayload(BaseModel):
|
||||
flow_id: str = Field(..., description="Workflow/Flow ID")
|
||||
node_id: str = Field(default="", description="Node ID for workflow context")
|
||||
current: str = Field(default="", description="Current instruction text")
|
||||
language: str = Field(default="javascript", description="Programming language (javascript/python)")
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
type: str = Field(..., description="Instruction template type")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(RuleGeneratePayload)
|
||||
reg(RuleCodeGeneratePayload)
|
||||
reg(RuleStructuredOutputPayload)
|
||||
reg(InstructionGeneratePayload)
|
||||
reg(InstructionTemplatePayload)
|
||||
|
||||
|
||||
@console_ns.route("/rule-generate")
|
||||
class RuleGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_config")
|
||||
@console_ns.doc(description="Generate rule configuration using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Rule configuration generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@api.doc("generate_rule_config")
|
||||
@api.doc(description="Generate rule configuration using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Rule generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Rule configuration generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=args["no_variable"],
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -102,25 +71,42 @@ class RuleGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/rule-code-generate")
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@console_ns.doc("generate_rule_code")
|
||||
@console_ns.doc(description="Generate code rules using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleCodeGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Code rules generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@api.doc("generate_rule_code")
|
||||
@api.doc(description="Generate code rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"RuleCodeGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Code generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"no_variable": fields.Boolean(required=True, default=False, description="Whether to exclude variables"),
|
||||
"code_language": fields.String(
|
||||
default="javascript", description="Programming language for code generation"
|
||||
),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Code rules generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -136,24 +122,35 @@ class RuleCodeGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/rule-structured-output-generate")
|
||||
class RuleStructuredOutputGenerateApi(Resource):
|
||||
@console_ns.doc("generate_structured_output")
|
||||
@console_ns.doc(description="Generate structured output rules using LLM")
|
||||
@console_ns.expect(console_ns.models[RuleStructuredOutputPayload.__name__])
|
||||
@console_ns.response(200, "Structured output generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@api.doc("generate_structured_output")
|
||||
@api.doc(description="Generate structured output rules using LLM")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"StructuredOutputGenerateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Structured output generation instruction"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Structured output generated successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -169,79 +166,101 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/instruction-generate")
|
||||
class InstructionGenerateApi(Resource):
|
||||
@console_ns.doc("generate_instruction")
|
||||
@console_ns.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@console_ns.expect(console_ns.models[InstructionGeneratePayload.__name__])
|
||||
@console_ns.response(200, "Instruction generated successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@console_ns.response(402, "Provider quota exceeded")
|
||||
@api.doc("generate_instruction")
|
||||
@api.doc(description="Generate instruction for workflow nodes or general use")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionGenerateRequest",
|
||||
{
|
||||
"flow_id": fields.String(required=True, description="Workflow/Flow ID"),
|
||||
"node_id": fields.String(description="Node ID for workflow context"),
|
||||
"current": fields.String(description="Current instruction text"),
|
||||
"language": fields.String(default="javascript", description="Programming language (javascript/python)"),
|
||||
"instruction": fields.String(required=True, description="Instruction for generation"),
|
||||
"model_config": fields.Raw(required=True, description="Model configuration"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Instruction generated successfully")
|
||||
@api.response(400, "Invalid request parameters or flow/workflow not found")
|
||||
@api.response(402, "Provider quota exceeded")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("flow_id", type=str, required=True, default="", location="json")
|
||||
parser.add_argument("node_id", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("current", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("language", type=str, required=False, default="javascript", location="json")
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("ideal_output", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
code_template = (
|
||||
Python3CodeProvider.get_default_code()
|
||||
if args["language"] == "python"
|
||||
else (JavascriptCodeProvider.get_default_code())
|
||||
if args["language"] == "javascript"
|
||||
else ""
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args.flow_id} not found"}, 400
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
workflow = WorkflowService().get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
return {"error": f"workflow {args.flow_id} not found"}, 400
|
||||
return {"error": f"workflow {args['flow_id']} not found"}, 400
|
||||
nodes: Sequence = workflow.graph_dict["nodes"]
|
||||
node = [node for node in nodes if node["id"] == args.node_id]
|
||||
node = [node for node in nodes if node["id"] == args["node_id"]]
|
||||
if len(node) == 0:
|
||||
return {"error": f"node {args.node_id} not found"}, 400
|
||||
return {"error": f"node {args['node_id']} not found"}, 400
|
||||
node_type = node[0]["data"]["type"]
|
||||
match node_type:
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
no_variable=True,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["language"],
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
if args["node_id"] == "" and args["current"] != "": # For legacy app without a workflow
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
)
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
if args["node_id"] != "" and args["current"] != "": # For workflow node
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
node_id=args.node_id,
|
||||
current=args.current,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
flow_id=args["flow_id"],
|
||||
node_id=args["node_id"],
|
||||
current=args["current"],
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
ideal_output=args["ideal_output"],
|
||||
workflow_service=WorkflowService(),
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
@ -257,17 +276,27 @@ class InstructionGenerateApi(Resource):
|
||||
|
||||
@console_ns.route("/instruction-generate/template")
|
||||
class InstructionGenerationTemplateApi(Resource):
|
||||
@console_ns.doc("get_instruction_template")
|
||||
@console_ns.doc(description="Get instruction generation template")
|
||||
@console_ns.expect(console_ns.models[InstructionTemplatePayload.__name__])
|
||||
@console_ns.response(200, "Template retrieved successfully")
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.doc("get_instruction_template")
|
||||
@api.doc(description="Get instruction generation template")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"InstructionTemplateRequest",
|
||||
{
|
||||
"instruction": fields.String(required=True, description="Template instruction"),
|
||||
"ideal_output": fields.String(description="Expected ideal output"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Template retrieved successfully")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = InstructionTemplatePayload.model_validate(console_ns.payload)
|
||||
match args.type:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
match args["type"]:
|
||||
case "prompt":
|
||||
from core.llm_generator.prompts import INSTRUCTION_GENERATE_TEMPLATE_PROMPT
|
||||
|
||||
@ -277,4 +306,4 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
|
||||
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
|
||||
case _:
|
||||
raise ValueError(f"Invalid type: {args.type}")
|
||||
raise ValueError(f"Invalid type: {args['type']}")
|
||||
|
||||
@ -1,20 +1,18 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_server_model = console_ns.model("AppServer", app_server_fields)
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
@ -23,24 +21,24 @@ class AppMCPServerStatus(StrEnum):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/server")
|
||||
class AppMCPServerController(Resource):
|
||||
@console_ns.doc("get_app_mcp_server")
|
||||
@console_ns.doc(description="Get MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
|
||||
@api.doc("get_app_mcp_server")
|
||||
@api.doc(description="Get MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "MCP server configuration retrieved successfully", app_server_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_model)
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@console_ns.doc("create_app_mcp_server")
|
||||
@console_ns.doc(description="Create MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("create_app_mcp_server")
|
||||
@api.doc(description="Create MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerCreateRequest",
|
||||
{
|
||||
"description": fields.String(description="Server description"),
|
||||
@ -48,21 +46,19 @@ class AppMCPServerController(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.response(201, "MCP server configuration created successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@login_required
|
||||
@setup_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
)
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
description = args.get("description")
|
||||
@ -75,18 +71,18 @@ class AppMCPServerController(Resource):
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@console_ns.doc("update_app_mcp_server")
|
||||
@console_ns.doc(description="Update MCP server configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_app_mcp_server")
|
||||
@api.doc(description="Update MCP server configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MCPServerUpdateRequest",
|
||||
{
|
||||
"id": fields.String(required=True, description="Server ID"),
|
||||
@ -96,23 +92,22 @@ class AppMCPServerController(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@get_app_model
|
||||
@login_required
|
||||
@api.response(200, "MCP server configuration updated successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("id", type=str, required=True, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("parameters", type=dict, required=True, location="json")
|
||||
.add_argument("status", type=str, required=False, location="json")
|
||||
)
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
@ -137,23 +132,23 @@ class AppMCPServerController(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@console_ns.doc("refresh_app_mcp_server")
|
||||
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@console_ns.doc(params={"server_id": "Server ID"})
|
||||
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "Server not found")
|
||||
@api.doc("refresh_app_mcp_server")
|
||||
@api.doc(description="Refresh MCP server configuration and regenerate server code")
|
||||
@api.doc(params={"server_id": "Server ID"})
|
||||
@api.response(200, "MCP server refreshed successfully", app_server_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "Server not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_model)
|
||||
@edit_permission_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = (
|
||||
db.session.query(AppMCPServer)
|
||||
.where(AppMCPServer.id == server_id)
|
||||
.where(AppMCPServer.tenant_id == current_tenant_id)
|
||||
.where(AppMCPServer.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not server:
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -18,233 +16,74 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from fields.conversation_fields import annotation_fields, message_detail_fields
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.annotation_service import AppAnnotationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ChatMessagesQuery(BaseModel):
|
||||
conversation_id: str = Field(..., description="Conversation ID")
|
||||
first_id: str | None = Field(default=None, description="First message ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
|
||||
|
||||
@field_validator("first_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@field_validator("conversation_id", "first_id")
|
||||
@classmethod
|
||||
def validate_uuid(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str) -> str:
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class FeedbackExportQuery(BaseModel):
|
||||
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
|
||||
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
|
||||
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
|
||||
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
|
||||
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
|
||||
|
||||
@field_validator("has_comment", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool | None:
|
||||
if isinstance(value, bool) or value is None:
|
||||
return value
|
||||
lowered = value.lower()
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("has_comment must be a boolean value")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]):
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(ChatMessagesQuery)
|
||||
reg(MessageFeedbackPayload)
|
||||
reg(FeedbackExportQuery)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model(
|
||||
"SimpleAccount",
|
||||
{
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
message_file_model = console_ns.model(
|
||||
"MessageFile",
|
||||
{
|
||||
"id": fields.String,
|
||||
"filename": fields.String,
|
||||
"type": fields.String,
|
||||
"url": fields.String,
|
||||
"mime_type": fields.String,
|
||||
"size": fields.Integer,
|
||||
"transfer_method": fields.String,
|
||||
"belongs_to": fields.String(default="user"),
|
||||
"upload_file_id": fields.String(default=None),
|
||||
},
|
||||
)
|
||||
|
||||
agent_thought_model = console_ns.model(
|
||||
"AgentThought",
|
||||
{
|
||||
"id": fields.String,
|
||||
"chain_id": fields.String,
|
||||
"message_id": fields.String,
|
||||
"position": fields.Integer,
|
||||
"thought": fields.String,
|
||||
"tool": fields.String,
|
||||
"tool_labels": fields.Raw,
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"files": fields.List(fields.String),
|
||||
},
|
||||
)
|
||||
|
||||
# Models that depend on simple_account_model
|
||||
feedback_model = console_ns.model(
|
||||
"Feedback",
|
||||
{
|
||||
"rating": fields.String,
|
||||
"content": fields.String,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
},
|
||||
)
|
||||
|
||||
annotation_model = console_ns.model(
|
||||
"Annotation",
|
||||
{
|
||||
"id": fields.String,
|
||||
"question": fields.String,
|
||||
"content": fields.String,
|
||||
"account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
annotation_hit_history_model = console_ns.model(
|
||||
"AnnotationHitHistory",
|
||||
{
|
||||
"annotation_id": fields.String(attribute="id"),
|
||||
"annotation_create_account": fields.Nested(simple_account_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
},
|
||||
)
|
||||
|
||||
# Message detail model that depends on multiple models
|
||||
message_detail_model = console_ns.model(
|
||||
"MessageDetail",
|
||||
{
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"message": fields.Raw,
|
||||
"message_tokens": fields.Integer,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"answer_tokens": fields.Integer,
|
||||
"provider_response_latency": fields.Float,
|
||||
"from_source": fields.String,
|
||||
"from_end_user_id": fields.String,
|
||||
"from_account_id": fields.String,
|
||||
"feedbacks": fields.List(fields.Nested(feedback_model)),
|
||||
"workflow_run_id": fields.String,
|
||||
"annotation": fields.Nested(annotation_model, allow_null=True),
|
||||
"annotation_hit_history": fields.Nested(annotation_hit_history_model, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
},
|
||||
)
|
||||
|
||||
# Message infinite scroll pagination model
|
||||
message_infinite_scroll_pagination_model = console_ns.model(
|
||||
"MessageInfiniteScrollPagination",
|
||||
{
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_model)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages")
|
||||
class ChatMessageListApi(Resource):
|
||||
@console_ns.doc("list_chat_messages")
|
||||
@console_ns.doc(description="Get chat messages for a conversation with pagination")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
|
||||
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
|
||||
@console_ns.response(404, "Conversation not found")
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
"limit": fields.Integer,
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(message_detail_fields)),
|
||||
}
|
||||
|
||||
@api.doc("list_chat_messages")
|
||||
@api.doc(description="Get chat messages for a conversation with pagination")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
|
||||
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
|
||||
)
|
||||
@api.response(200, "Success", message_infinite_scroll_pagination_fields)
|
||||
@api.response(404, "Conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
def get(self, app_model):
|
||||
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("conversation_id", required=True, type=uuid_value, location="args")
|
||||
parser.add_argument("first_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
|
||||
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
if args.first_id:
|
||||
if args["first_id"]:
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
|
||||
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -259,7 +98,7 @@ class ChatMessageListApi(Resource):
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
@ -267,12 +106,12 @@ class ChatMessageListApi(Resource):
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(args.limit)
|
||||
.limit(args["limit"])
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
if len(history_messages) == args.limit:
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
@ -290,28 +129,40 @@ class ChatMessageListApi(Resource):
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@console_ns.doc("create_message_feedback")
|
||||
@console_ns.doc(description="Create or update message feedback (like/dislike)")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
|
||||
@console_ns.response(200, "Feedback updated successfully")
|
||||
@console_ns.response(404, "Message not found")
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@api.doc("create_message_feedback")
|
||||
@api.doc(description="Create or update message feedback (like/dislike)")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageFeedbackRequest",
|
||||
{
|
||||
"message_id": fields.String(required=True, description="Message ID"),
|
||||
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Feedback updated successfully")
|
||||
@api.response(404, "Message not found")
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user is None:
|
||||
raise Forbidden()
|
||||
|
||||
args = MessageFeedbackPayload.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=True, type=uuid_value, location="json")
|
||||
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = str(args.message_id)
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -320,21 +171,18 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
feedback = message.admin_feedback
|
||||
|
||||
if not args.rating and feedback:
|
||||
if not args["rating"] and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
elif not args.rating and not feedback:
|
||||
elif args["rating"] and feedback:
|
||||
feedback.rating = args["rating"]
|
||||
elif not args["rating"] and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
rating_value = args.rating
|
||||
if rating_value is None:
|
||||
raise ValueError("rating is required to create feedback")
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
rating=args["rating"],
|
||||
from_source="admin",
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
@ -345,15 +193,56 @@ class MessageFeedbackApi(Resource):
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations")
|
||||
class MessageAnnotationApi(Resource):
|
||||
@api.doc("create_message_annotation")
|
||||
@api.doc(description="Create message annotation")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"MessageAnnotationRequest",
|
||||
{
|
||||
"message_id": fields.String(description="Message ID"),
|
||||
"question": fields.String(required=True, description="Question text"),
|
||||
"answer": fields.String(required=True, description="Answer text"),
|
||||
"annotation_reply": fields.Raw(description="Annotation reply"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.response(200, "Annotation created successfully", annotation_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("annotation")
|
||||
@get_app_model
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", required=False, type=uuid_value, location="json")
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
parser.add_argument("annotation_reply", required=False, type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id)
|
||||
|
||||
return annotation
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/count")
|
||||
class MessageAnnotationCountApi(Resource):
|
||||
@console_ns.doc("get_annotation_count")
|
||||
@console_ns.doc(description="Get count of message annotations for the app")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(
|
||||
@api.doc("get_annotation_count")
|
||||
@api.doc(description="Get count of message annotations for the app")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Annotation count retrieved successfully",
|
||||
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
api.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
|
||||
)
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@ -367,23 +256,20 @@ class MessageAnnotationCountApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions")
|
||||
class MessageSuggestedQuestionApi(Resource):
|
||||
@console_ns.doc("get_message_suggested_questions")
|
||||
@console_ns.doc(description="Get suggested questions for a message")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(
|
||||
@api.doc("get_message_suggested_questions")
|
||||
@api.doc(description="Get suggested questions for a message")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Suggested questions retrieved successfully",
|
||||
console_ns.model(
|
||||
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
|
||||
),
|
||||
api.model("SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}),
|
||||
)
|
||||
@console_ns.response(404, "Message or conversation not found")
|
||||
@api.response(404, "Message or conversation not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
try:
|
||||
@ -411,59 +297,19 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
return {"data": questions}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
|
||||
class MessageFeedbackExportApi(Resource):
|
||||
@console_ns.doc("export_feedbacks")
|
||||
@console_ns.doc(description="Export user feedback data for Google Sheets")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
|
||||
@console_ns.response(200, "Feedback data exported successfully")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(500, "Internal server error")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
# Import the service function
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
try:
|
||||
export_data = FeedbackService.export_feedbacks(
|
||||
app_id=app_model.id,
|
||||
from_source=args.from_source,
|
||||
rating=args.rating,
|
||||
has_comment=args.has_comment,
|
||||
start_date=args.start_date,
|
||||
end_date=args.end_date,
|
||||
format_type=args.format,
|
||||
)
|
||||
|
||||
return export_data
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Parameter validation error in feedback export")
|
||||
return {"error": f"Parameter validation error: {str(e)}"}, 400
|
||||
except Exception as e:
|
||||
logger.exception("Error exporting feedback data")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/messages/<uuid:message_id>")
|
||||
class MessageApi(Resource):
|
||||
@console_ns.doc("get_message")
|
||||
@console_ns.doc(description="Get message details by ID")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@console_ns.response(200, "Message retrieved successfully", message_detail_model)
|
||||
@console_ns.response(404, "Message not found")
|
||||
@get_app_model
|
||||
@api.doc("get_message")
|
||||
@api.doc(description="Get message details by ID")
|
||||
@api.doc(params={"app_id": "Application ID", "message_id": "Message ID"})
|
||||
@api.response(200, "Message retrieved successfully", message_detail_fields)
|
||||
@api.response(404, "Message not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(message_detail_model)
|
||||
def get(self, app_model, message_id: str):
|
||||
@get_app_model
|
||||
@marshal_with(message_detail_fields)
|
||||
def get(self, app_model, message_id):
|
||||
message_id = str(message_id)
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
@ -2,29 +2,31 @@ import json
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/model-config")
|
||||
class ModelConfigResource(Resource):
|
||||
@console_ns.doc("update_app_model_config")
|
||||
@console_ns.doc(description="Update application model configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_app_model_config")
|
||||
@api.doc(description="Update application model configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ModelConfigRequest",
|
||||
{
|
||||
"provider": fields.String(description="Model provider"),
|
||||
@ -42,20 +44,25 @@ class ModelConfigResource(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Model configuration updated successfully")
|
||||
@console_ns.response(400, "Invalid configuration")
|
||||
@console_ns.response(404, "App not found")
|
||||
@api.response(200, "Model configuration updated successfully")
|
||||
@api.response(400, "Invalid configuration")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION])
|
||||
def post(self, app_model):
|
||||
"""Modify app model config"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
|
||||
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
|
||||
# validate config
|
||||
model_configuration = AppModelConfigService.validate_configuration(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
config=cast(dict, request.json),
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
@ -83,16 +90,16 @@ class ModelConfigResource(Resource):
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -117,7 +124,7 @@ class ModelConfigResource(Resource):
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
agent_tool_entity = AgentToolEntity.model_validate(tool)
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
|
||||
# get tool
|
||||
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"
|
||||
@ -126,7 +133,7 @@ class ModelConfigResource(Resource):
|
||||
else:
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
agent_tool=agent_tool_entity,
|
||||
)
|
||||
@ -134,7 +141,7 @@ class ModelConfigResource(Resource):
|
||||
continue
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
@ -165,8 +172,6 @@ class ModelConfigResource(Resource):
|
||||
db.session.flush()
|
||||
|
||||
app_model.app_model_config_id = new_app_model_config.id
|
||||
app_model.updated_by = current_user.id
|
||||
app_model.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
@ -14,23 +14,24 @@ class TraceAppConfigApi(Resource):
|
||||
Manage trace app configurations
|
||||
"""
|
||||
|
||||
@console_ns.doc("get_trace_app_config")
|
||||
@console_ns.doc(description="Get tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
@api.doc("get_trace_app_config")
|
||||
@api.doc(description="Get tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
200, "Tracing configuration retrieved successfully", fields.Raw(description="Tracing configuration data")
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id):
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -41,11 +42,11 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@console_ns.doc("create_trace_app_config")
|
||||
@console_ns.doc(description="Create a new tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("create_trace_app_config")
|
||||
@api.doc(description="Create a new tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TraceConfigCreateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
@ -53,20 +54,18 @@ class TraceAppConfigApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
201, "Tracing configuration created successfully", fields.Raw(description="Created configuration data")
|
||||
)
|
||||
@console_ns.response(400, "Invalid request parameters or configuration already exists")
|
||||
@api.response(400, "Invalid request parameters or configuration already exists")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, app_id):
|
||||
"""Create a new trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -81,11 +80,11 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@console_ns.doc("update_trace_app_config")
|
||||
@console_ns.doc(description="Update an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_trace_app_config")
|
||||
@api.doc(description="Update an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"TraceConfigUpdateRequest",
|
||||
{
|
||||
"tracing_provider": fields.String(required=True, description="Tracing provider name"),
|
||||
@ -93,18 +92,16 @@ class TraceAppConfigApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@api.response(200, "Tracing configuration updated successfully", fields.Raw(description="Success response"))
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, app_id):
|
||||
"""Update an existing trace app configuration"""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="json")
|
||||
parser.add_argument("tracing_config", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -117,22 +114,23 @@ class TraceAppConfigApi(Resource):
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@console_ns.doc("delete_trace_app_config")
|
||||
@console_ns.doc(description="Delete an existing tracing configuration for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.parser().add_argument(
|
||||
@api.doc("delete_trace_app_config")
|
||||
@api.doc(description="Delete an existing tracing configuration for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser().add_argument(
|
||||
"tracing_provider", type=str, required=True, location="args", help="Tracing provider name"
|
||||
)
|
||||
)
|
||||
@console_ns.response(204, "Tracing configuration deleted successfully")
|
||||
@console_ns.response(400, "Invalid request parameters or configuration not found")
|
||||
@api.response(204, "Tracing configuration deleted successfully")
|
||||
@api.response(400, "Invalid request parameters or configuration not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, app_id):
|
||||
"""Delete an existing trace app configuration"""
|
||||
parser = reqparse.RequestParser().add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tracing_provider", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,61 +1,48 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Site
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
app_site_model = console_ns.model("AppSite", app_site_fields)
|
||||
from libs.login import login_required
|
||||
from models import Account, Site
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("title", type=str, required=False, location="json")
|
||||
.add_argument("icon_type", type=str, required=False, location="json")
|
||||
.add_argument("icon", type=str, required=False, location="json")
|
||||
.add_argument("icon_background", type=str, required=False, location="json")
|
||||
.add_argument("description", type=str, required=False, location="json")
|
||||
.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
.add_argument("copyright", type=str, required=False, location="json")
|
||||
.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
.add_argument(
|
||||
"customize_token_strategy",
|
||||
type=str,
|
||||
choices=["must", "allow", "not_allow"],
|
||||
required=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("title", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, location="json")
|
||||
parser.add_argument("description", type=str, required=False, location="json")
|
||||
parser.add_argument("default_language", type=supported_language, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme", type=str, required=False, location="json")
|
||||
parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json")
|
||||
parser.add_argument("customize_domain", type=str, required=False, location="json")
|
||||
parser.add_argument("copyright", type=str, required=False, location="json")
|
||||
parser.add_argument("privacy_policy", type=str, required=False, location="json")
|
||||
parser.add_argument("custom_disclaimer", type=str, required=False, location="json")
|
||||
parser.add_argument(
|
||||
"customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json"
|
||||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site")
|
||||
class AppSite(Resource):
|
||||
@console_ns.doc("update_app_site")
|
||||
@console_ns.doc(description="Update application site configuration")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_app_site")
|
||||
@api.doc(description="Update application site configuration")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"AppSiteRequest",
|
||||
{
|
||||
"title": fields.String(description="Site title"),
|
||||
@ -79,18 +66,21 @@ class AppSite(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions")
|
||||
@console_ns.response(404, "App not found")
|
||||
@api.response(200, "Site configuration updated successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions")
|
||||
@api.response(404, "App not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
args = parse_app_site_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if not site:
|
||||
raise NotFound
|
||||
@ -117,6 +107,8 @@ class AppSite(Resource):
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@ -126,26 +118,30 @@ class AppSite(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
|
||||
class AppSiteAccessTokenReset(Resource):
|
||||
@console_ns.doc("reset_app_site_access_token")
|
||||
@console_ns.doc(description="Reset access token for application site")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Access token reset successfully", app_site_model)
|
||||
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@console_ns.response(404, "App or site not found")
|
||||
@api.doc("reset_app_site_access_token")
|
||||
@api.doc(description="Reset access token for application site")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Access token reset successfully", app_site_fields)
|
||||
@api.response(403, "Insufficient permissions (admin/owner required)")
|
||||
@api.response(404, "App or site not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_site_model)
|
||||
@marshal_with(app_site_fields)
|
||||
def post(self, app_model):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
|
||||
if not site:
|
||||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -1,48 +1,33 @@
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.helper import convert_datetime_to_date
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AppMode
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def empty_string_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
StatisticTimeRangeQuery.__name__,
|
||||
StatisticTimeRangeQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models import AppMode, Message
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-messages")
|
||||
class DailyMessageStatistic(Resource):
|
||||
@console_ns.doc("get_daily_message_statistics")
|
||||
@console_ns.doc(description="Get daily message statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_daily_message_statistics")
|
||||
@api.doc(description="Get daily message statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily message statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily message count data")),
|
||||
@ -52,32 +37,43 @@ class DailyMessageStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(*) AS message_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -95,11 +91,15 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-conversations")
|
||||
class DailyConversationStatistic(Resource):
|
||||
@console_ns.doc("get_daily_conversation_statistics")
|
||||
@console_ns.doc(description="Get daily conversation statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_daily_conversation_statistics")
|
||||
@api.doc(description="Get daily conversation statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily conversation statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily conversation count data")),
|
||||
@ -109,53 +109,63 @@ class DailyConversationStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
stmt = (
|
||||
sa.select(
|
||||
sa.func.date(
|
||||
sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
).label("date"),
|
||||
sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
|
||||
)
|
||||
.select_from(Message)
|
||||
.where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
|
||||
)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at >= start_datetime_utc)
|
||||
|
||||
if end_datetime_utc:
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
stmt = stmt.where(Message.created_at < end_datetime_utc)
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
stmt = stmt.group_by("date").order_by("date")
|
||||
|
||||
response_data = []
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
|
||||
rs = conn.execute(stmt, {"tz": account.timezone})
|
||||
for row in rs:
|
||||
response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/daily-end-users")
|
||||
class DailyTerminalsStatistic(Resource):
|
||||
@console_ns.doc("get_daily_terminals_statistics")
|
||||
@console_ns.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_daily_terminals_statistics")
|
||||
@api.doc(description="Get daily terminal/end-user statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily terminal statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily terminal count data")),
|
||||
@ -165,32 +175,43 @@ class DailyTerminalsStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT messages.from_end_user_id) AS terminal_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -208,11 +229,15 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/token-costs")
|
||||
class DailyTokenCostStatistic(Resource):
|
||||
@console_ns.doc("get_daily_token_cost_statistics")
|
||||
@console_ns.doc(description="Get daily token cost statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_daily_token_cost_statistics")
|
||||
@api.doc(description="Get daily token cost statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Daily token cost statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Daily token cost data")),
|
||||
@ -222,13 +247,15 @@ class DailyTokenCostStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
(SUM(messages.message_tokens) + SUM(messages.answer_tokens)) AS token_count,
|
||||
SUM(total_price) AS total_price
|
||||
FROM
|
||||
@ -236,19 +263,28 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -268,11 +304,15 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-session-interactions")
|
||||
class AverageSessionInteractionStatistic(Resource):
|
||||
@console_ns.doc("get_average_session_interaction_statistics")
|
||||
@console_ns.doc(description="Get average session interaction statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_average_session_interaction_statistics")
|
||||
@api.doc(description="Get average session interaction statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Average session interaction statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average session interaction data")),
|
||||
@ -282,13 +322,15 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("c.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(subquery.message_count) AS interactions
|
||||
FROM
|
||||
(
|
||||
@ -303,19 +345,28 @@ FROM
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND c.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND c.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -344,11 +395,15 @@ ORDER BY
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/user-satisfaction-rate")
|
||||
class UserSatisfactionRateStatistic(Resource):
|
||||
@console_ns.doc("get_user_satisfaction_rate_statistics")
|
||||
@console_ns.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_user_satisfaction_rate_statistics")
|
||||
@api.doc(description="Get user satisfaction rate statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"User satisfaction rate statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="User satisfaction rate data")),
|
||||
@ -358,13 +413,15 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("m.created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(m.id) AS message_count,
|
||||
COUNT(mf.id) AS feedback_count
|
||||
FROM
|
||||
@ -375,19 +432,28 @@ LEFT JOIN
|
||||
WHERE
|
||||
m.app_id = :app_id
|
||||
AND m.invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND m.created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND m.created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -410,11 +476,15 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/average-response-time")
|
||||
class AverageResponseTimeStatistic(Resource):
|
||||
@console_ns.doc("get_average_response_time_statistics")
|
||||
@console_ns.doc(description="Get average response time statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_average_response_time_statistics")
|
||||
@api.doc(description="Get average response time statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Average response time statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Average response time data")),
|
||||
@ -424,32 +494,43 @@ class AverageResponseTimeStatistic(Resource):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
AVG(provider_response_latency) AS latency
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
@ -467,11 +548,15 @@ WHERE
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/statistics/tokens-per-second")
|
||||
class TokensPerSecondStatistic(Resource):
|
||||
@console_ns.doc("get_tokens_per_second_statistics")
|
||||
@console_ns.doc(description="Get tokens per second statistics for an application")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[StatisticTimeRangeQuery.__name__])
|
||||
@console_ns.response(
|
||||
@api.doc("get_tokens_per_second_statistics")
|
||||
@api.doc(description="Get tokens per second statistics for an application")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.expect(
|
||||
api.parser()
|
||||
.add_argument("start", type=str, location="args", help="Start date (YYYY-MM-DD HH:MM)")
|
||||
.add_argument("end", type=str, location="args", help="End date (YYYY-MM-DD HH:MM)")
|
||||
)
|
||||
@api.response(
|
||||
200,
|
||||
"Tokens per second statistics retrieved successfully",
|
||||
fields.List(fields.Raw(description="Tokens per second data")),
|
||||
@ -481,12 +566,15 @@ class TokensPerSecondStatistic(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
account = current_user
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
CASE
|
||||
WHEN SUM(provider_response_latency) = 0 THEN 0
|
||||
ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
|
||||
@ -496,19 +584,28 @@ FROM
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER.value}
|
||||
|
||||
try:
|
||||
start_datetime_utc, end_datetime_utc = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if end_datetime_utc:
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,85 +1,82 @@
|
||||
from datetime import datetime
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowAppLogQuery(BaseModel):
|
||||
keyword: str | None = Field(default=None, description="Search keyword for filtering logs")
|
||||
status: WorkflowExecutionStatus | None = Field(
|
||||
default=None, description="Execution status filter (succeeded, failed, stopped, partial-succeeded)"
|
||||
)
|
||||
created_at__before: datetime | None = Field(default=None, description="Filter logs created before this timestamp")
|
||||
created_at__after: datetime | None = Field(default=None, description="Filter logs created after this timestamp")
|
||||
created_by_end_user_session_id: str | None = Field(default=None, description="Filter by end user session ID")
|
||||
created_by_account: str | None = Field(default=None, description="Filter by account")
|
||||
detail: bool = Field(default=False, description="Whether to return detailed logs")
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
|
||||
@field_validator("created_at__before", "created_at__after", mode="before")
|
||||
@classmethod
|
||||
def parse_datetime(cls, value: str | None) -> datetime | None:
|
||||
if value in (None, ""):
|
||||
return None
|
||||
return isoparse(value) # type: ignore
|
||||
|
||||
@field_validator("detail", mode="before")
|
||||
@classmethod
|
||||
def parse_bool(cls, value: bool | str | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
lowered = value.lower()
|
||||
if lowered in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if lowered in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
raise ValueError("Invalid boolean value for detail")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowAppLogQuery.__name__, WorkflowAppLogQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
# Register model for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_app_log_pagination_model = build_workflow_app_log_pagination_model(console_ns)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-app-logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@console_ns.doc("get_workflow_app_logs")
|
||||
@console_ns.doc(description="Get workflow application execution logs")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowAppLogQuery.__name__])
|
||||
@console_ns.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_model)
|
||||
@api.doc("get_workflow_app_logs")
|
||||
@api.doc(description="Get workflow application execution logs")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(
|
||||
params={
|
||||
"keyword": "Search keyword for filtering logs",
|
||||
"status": "Filter by execution status (succeeded, failed, stopped, partial-succeeded)",
|
||||
"created_at__before": "Filter logs created before this timestamp",
|
||||
"created_at__after": "Filter logs created after this timestamp",
|
||||
"created_by_end_user_session_id": "Filter by end user session ID",
|
||||
"created_by_account": "Filter by account",
|
||||
"page": "Page number (1-99999)",
|
||||
"limit": "Number of items per page (1-100)",
|
||||
}
|
||||
)
|
||||
@api.response(200, "Workflow app logs retrieved successfully", workflow_app_log_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_app_log_pagination_model)
|
||||
@marshal_with(workflow_app_log_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow app logs
|
||||
"""
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@ -93,7 +90,6 @@ class WorkflowAppLogApi(Resource):
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
detail=args.detail,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
@ -1,18 +1,17 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import NoReturn, ParamSpec, TypeVar
|
||||
from typing import NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
@ -22,8 +21,9 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -58,18 +58,16 @@ def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
@ -141,42 +139,8 @@ _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
workflow_draft_variable_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableWithoutValue", _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS
|
||||
)
|
||||
|
||||
workflow_draft_variable_model = console_ns.model("WorkflowDraftVariable", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_model = console_ns.model("WorkflowDraftEnvVariable", _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)
|
||||
|
||||
workflow_draft_env_variable_list_fields_copy = _WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_env_variable_list_fields_copy["items"] = fields.List(fields.Nested(workflow_draft_env_variable_model))
|
||||
workflow_draft_env_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftEnvVariableList", workflow_draft_env_variable_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_without_value_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS.copy()
|
||||
workflow_draft_variable_list_without_value_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_without_value_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_without_value_model = console_ns.model(
|
||||
"WorkflowDraftVariableListWithoutValue", workflow_draft_variable_list_without_value_fields_copy
|
||||
)
|
||||
|
||||
workflow_draft_variable_list_fields_copy = _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS.copy()
|
||||
workflow_draft_variable_list_fields_copy["items"] = fields.List(
|
||||
fields.Nested(workflow_draft_variable_model), attribute=_get_items
|
||||
)
|
||||
workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -190,10 +154,11 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@ -201,16 +166,13 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables")
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@console_ns.expect(_create_pagination_parser())
|
||||
@console_ns.doc("get_workflow_variables")
|
||||
@console_ns.doc(description="Get draft workflow variables")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.response(
|
||||
200, "Workflow variables retrieved successfully", workflow_draft_variable_list_without_value_model
|
||||
)
|
||||
@api.doc("get_workflow_variables")
|
||||
@api.doc(description="Get draft workflow variables")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"page": "Page number (1-100000)", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_without_value_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
@ -237,9 +199,9 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@console_ns.doc("delete_workflow_variables")
|
||||
@console_ns.doc(description="Delete all draft workflow variables")
|
||||
@console_ns.response(204, "Workflow variables deleted successfully")
|
||||
@api.doc("delete_workflow_variables")
|
||||
@api.doc(description="Delete all draft workflow variables")
|
||||
@api.response(204, "Workflow variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -270,12 +232,12 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_node_variables")
|
||||
@console_ns.doc(description="Get variables for a specific node")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.response(200, "Node variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@api.doc("get_node_variables")
|
||||
@api.doc(description="Get variables for a specific node")
|
||||
@api.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@api.response(200, "Node variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
@ -286,9 +248,9 @@ class NodeVariableCollectionApi(Resource):
|
||||
|
||||
return node_vars
|
||||
|
||||
@console_ns.doc("delete_node_variables")
|
||||
@console_ns.doc(description="Delete all variables for a specific node")
|
||||
@console_ns.response(204, "Node variables deleted successfully")
|
||||
@api.doc("delete_node_variables")
|
||||
@api.doc(description="Delete all variables for a specific node")
|
||||
@api.response(204, "Node variables deleted successfully")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
@ -303,13 +265,13 @@ class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@console_ns.doc("get_variable")
|
||||
@console_ns.doc(description="Get a specific workflow variable")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable retrieved successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@api.doc("get_variable")
|
||||
@api.doc(description="Get a specific workflow variable")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
@ -321,10 +283,10 @@ class VariableApi(Resource):
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@console_ns.doc("update_variable")
|
||||
@console_ns.doc(description="Update a workflow variable")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_variable")
|
||||
@api.doc(description="Update a workflow variable")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateVariableRequest",
|
||||
{
|
||||
"name": fields.String(description="Variable name"),
|
||||
@ -332,10 +294,10 @@ class VariableApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Variable updated successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@api.response(200, "Variable updated successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
@ -358,11 +320,10 @@ class VariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
@ -397,10 +358,10 @@ class VariableApi(Resource):
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@console_ns.doc("delete_variable")
|
||||
@console_ns.doc(description="Delete a workflow variable")
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@api.doc("delete_variable")
|
||||
@api.doc(description="Delete a workflow variable")
|
||||
@api.response(204, "Variable deleted successfully")
|
||||
@api.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -418,12 +379,12 @@ class VariableApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class VariableResetApi(Resource):
|
||||
@console_ns.doc("reset_variable")
|
||||
@console_ns.doc(description="Reset a workflow variable to its default value")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@console_ns.response(200, "Variable reset successfully", workflow_draft_variable_model)
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@api.doc("reset_variable")
|
||||
@api.doc(description="Reset a workflow variable to its default value")
|
||||
@api.doc(params={"app_id": "Application ID", "variable_id": "Variable ID"})
|
||||
@api.response(200, "Variable reset successfully", _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@api.response(204, "Variable reset (no content)")
|
||||
@api.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
@ -447,7 +408,7 @@ class VariableResetApi(Resource):
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, workflow_draft_variable_model)
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
@ -466,13 +427,13 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_conversation_variables")
|
||||
@console_ns.doc(description="Get conversation variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@api.doc("get_conversation_variables")
|
||||
@api.doc(description="Get conversation variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Conversation variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
@ -488,23 +449,23 @@ class ConversationVariableCollectionApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_system_variables")
|
||||
@console_ns.doc(description="Get system variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "System variables retrieved successfully", workflow_draft_variable_list_model)
|
||||
@api.doc("get_system_variables")
|
||||
@api.doc(description="Get system variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "System variables retrieved successfully", _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@console_ns.doc("get_environment_variables")
|
||||
@console_ns.doc(description="Get environment variables for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables retrieved successfully")
|
||||
@console_ns.response(404, "Draft workflow not found")
|
||||
@api.doc("get_environment_variables")
|
||||
@api.doc(description="Get environment variables for workflow")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.response(200, "Environment variables retrieved successfully")
|
||||
@api.response(404, "Draft workflow not found")
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
|
||||
@ -1,341 +1,90 @@
|
||||
from typing import Literal, cast
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.workflow_run_fields import (
|
||||
advanced_chat_workflow_run_for_list_fields,
|
||||
advanced_chat_workflow_run_pagination_fields,
|
||||
workflow_run_count_fields,
|
||||
workflow_run_detail_fields,
|
||||
workflow_run_for_list_fields,
|
||||
workflow_run_node_execution_fields,
|
||||
workflow_run_node_execution_list_fields,
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowRunTriggeredFrom
|
||||
from libs.login import login_required
|
||||
from models import Account, App, AppMode, EndUser
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
|
||||
# Base models
|
||||
simple_account_model = console_ns.model("SimpleAccount", simple_account_fields)
|
||||
|
||||
simple_end_user_model = console_ns.model("SimpleEndUser", simple_end_user_fields)
|
||||
|
||||
# Models that depend on simple_account_fields
|
||||
workflow_run_for_list_fields_copy = workflow_run_for_list_fields.copy()
|
||||
workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_for_list_model = console_ns.model("WorkflowRunForList", workflow_run_for_list_fields_copy)
|
||||
|
||||
advanced_chat_workflow_run_for_list_fields_copy = advanced_chat_workflow_run_for_list_fields.copy()
|
||||
advanced_chat_workflow_run_for_list_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
advanced_chat_workflow_run_for_list_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunForList", advanced_chat_workflow_run_for_list_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_detail_fields_copy = workflow_run_detail_fields.copy()
|
||||
workflow_run_detail_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_detail_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_detail_model = console_ns.model("WorkflowRunDetail", workflow_run_detail_fields_copy)
|
||||
|
||||
workflow_run_node_execution_fields_copy = workflow_run_node_execution_fields.copy()
|
||||
workflow_run_node_execution_fields_copy["created_by_account"] = fields.Nested(
|
||||
simple_account_model, attribute="created_by_account", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_fields_copy["created_by_end_user"] = fields.Nested(
|
||||
simple_end_user_model, attribute="created_by_end_user", allow_null=True
|
||||
)
|
||||
workflow_run_node_execution_model = console_ns.model(
|
||||
"WorkflowRunNodeExecution", workflow_run_node_execution_fields_copy
|
||||
)
|
||||
|
||||
# Simple models without nested dependencies
|
||||
workflow_run_count_model = console_ns.model("WorkflowRunCount", workflow_run_count_fields)
|
||||
|
||||
# Pagination models that depend on list models
|
||||
advanced_chat_workflow_run_pagination_fields_copy = advanced_chat_workflow_run_pagination_fields.copy()
|
||||
advanced_chat_workflow_run_pagination_fields_copy["data"] = fields.List(
|
||||
fields.Nested(advanced_chat_workflow_run_for_list_model), attribute="data"
|
||||
)
|
||||
advanced_chat_workflow_run_pagination_model = console_ns.model(
|
||||
"AdvancedChatWorkflowRunPagination", advanced_chat_workflow_run_pagination_fields_copy
|
||||
)
|
||||
|
||||
workflow_run_pagination_fields_copy = workflow_run_pagination_fields.copy()
|
||||
workflow_run_pagination_fields_copy["data"] = fields.List(fields.Nested(workflow_run_for_list_model), attribute="data")
|
||||
workflow_run_pagination_model = console_ns.model("WorkflowRunPagination", workflow_run_pagination_fields_copy)
|
||||
|
||||
workflow_run_node_execution_list_fields_copy = workflow_run_node_execution_list_fields.copy()
|
||||
workflow_run_node_execution_list_fields_copy["data"] = fields.List(fields.Nested(workflow_run_node_execution_model))
|
||||
workflow_run_node_execution_list_model = console_ns.model(
|
||||
"WorkflowRunNodeExecutionList", workflow_run_node_execution_list_fields_copy
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunListQuery(BaseModel):
|
||||
last_id: str | None = Field(default=None, description="Last run ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of items per page (1-100)")
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("last_id")
|
||||
@classmethod
|
||||
def validate_last_id(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class WorkflowRunCountQuery(BaseModel):
|
||||
status: Literal["running", "succeeded", "failed", "stopped", "partial-succeeded"] | None = Field(
|
||||
default=None, description="Workflow run status filter"
|
||||
)
|
||||
time_range: str | None = Field(default=None, description="Time range filter (e.g., 7d, 4h, 30m, 30s)")
|
||||
triggered_from: Literal["debugging", "app-run"] | None = Field(
|
||||
default=None, description="Filter by trigger source: debugging or app-run"
|
||||
)
|
||||
|
||||
@field_validator("time_range")
|
||||
@classmethod
|
||||
def validate_time_range(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return time_duration(value)
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowRunListQuery.__name__, WorkflowRunListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
WorkflowRunCountQuery.__name__,
|
||||
WorkflowRunCountQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs")
|
||||
class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs")
|
||||
@console_ns.doc(description="Get advanced chat workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_model)
|
||||
@api.doc("get_advanced_chat_workflow_runs")
|
||||
@api.doc(description="Get advanced chat workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", advanced_chat_workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflow-runs/count")
|
||||
class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_workflow_runs_count")
|
||||
@console_ns.doc(description="Get advanced chat workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs")
|
||||
class WorkflowRunListApi(Resource):
|
||||
@console_ns.doc("get_workflow_runs")
|
||||
@console_ns.doc(description="Get workflow run list")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunListQuery.__name__])
|
||||
@api.doc("get_workflow_runs")
|
||||
@api.doc(description="Get workflow run list")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"last_id": "Last run ID for pagination", "limit": "Number of items per page (1-100)"})
|
||||
@api.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_fields)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
@marshal_with(workflow_run_pagination_fields)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("last_id", type=uuid_value, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_paginate_workflow_runs(
|
||||
app_model=app_model, args=args, triggered_from=triggered_from
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/count")
|
||||
class WorkflowRunCountApi(Resource):
|
||||
@console_ns.doc("get_workflow_runs_count")
|
||||
@console_ns.doc(description="Get workflow runs count statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.doc(
|
||||
params={"status": "Filter by status (optional): running, succeeded, failed, stopped, partial-succeeded"}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"time_range": (
|
||||
"Filter by time range (optional): e.g., 7d (7 days), 4h (4 hours), "
|
||||
"30m (30 minutes), 30s (30 seconds). Filters by created_at field."
|
||||
)
|
||||
}
|
||||
)
|
||||
@console_ns.doc(
|
||||
params={"triggered_from": "Filter by trigger source (optional): debugging or app-run. Default: debugging"}
|
||||
)
|
||||
@console_ns.response(200, "Workflow runs count retrieved successfully", workflow_run_count_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowRunCountQuery.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow runs count statistics
|
||||
"""
|
||||
args_model = WorkflowRunCountQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom(args_model.triggered_from)
|
||||
if args_model.triggered_from
|
||||
else WorkflowRunTriggeredFrom.DEBUGGING
|
||||
)
|
||||
|
||||
workflow_run_service = WorkflowRunService()
|
||||
result = workflow_run_service.get_workflow_runs_count(
|
||||
app_model=app_model,
|
||||
status=args.get("status"),
|
||||
time_range=args.get("time_range"),
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>")
|
||||
class WorkflowRunDetailApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_detail")
|
||||
@console_ns.doc(description="Get workflow run detail")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@api.doc("get_workflow_run_detail")
|
||||
@api.doc(description="Get workflow run detail")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
@marshal_with(workflow_run_detail_fields)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run detail
|
||||
@ -350,16 +99,16 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions")
|
||||
class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@console_ns.doc("get_workflow_run_node_executions")
|
||||
@console_ns.doc(description="Get workflow run node execution list")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
|
||||
@console_ns.response(404, "Workflow run not found")
|
||||
@api.doc("get_workflow_run_node_executions")
|
||||
@api.doc(description="Get workflow run node execution list")
|
||||
@api.doc(params={"app_id": "Application ID", "run_id": "Workflow run ID"})
|
||||
@api.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_fields)
|
||||
@api.response(404, "Workflow run not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
@marshal_with(workflow_run_node_execution_list_fields)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
|
||||
@ -1,194 +1,311 @@
|
||||
from flask import abort, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from controllers.console import console_ns
|
||||
import pytz
|
||||
import sqlalchemy as sa
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowStatisticQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date and time (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date and time (YYYY-MM-DD HH:MM)")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
def blank_to_none(cls, value: str | None) -> str | None:
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowStatisticQuery.__name__,
|
||||
WorkflowStatisticQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-conversations")
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_runs_statistic")
|
||||
@console_ns.doc(description="Get workflow daily runs statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily runs statistics retrieved successfully")
|
||||
@api.doc("get_workflow_daily_runs_statistic")
|
||||
@api.doc(description="Get workflow daily runs statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily runs statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(id) AS runs
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_runs_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "runs": i.runs})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/daily-terminals")
|
||||
class WorkflowDailyTerminalsStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_terminals_statistic")
|
||||
@console_ns.doc(description="Get workflow daily terminals statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@api.doc("get_workflow_daily_terminals_statistic")
|
||||
@api.doc(description="Get workflow daily terminals statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily terminals statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
COUNT(DISTINCT workflow_runs.created_by) AS terminal_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_terminals_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append({"date": str(i.date), "terminal_count": i.terminal_count})
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/token-costs")
|
||||
class WorkflowDailyTokenCostStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_daily_token_cost_statistic")
|
||||
@console_ns.doc(description="Get workflow daily token cost statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@api.doc("get_workflow_daily_token_cost_statistic")
|
||||
@api.doc(description="Get workflow daily token cost statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Daily token cost statistics retrieved successfully")
|
||||
@get_app_model
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
sql_query = """SELECT
|
||||
DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
SUM(workflow_runs.total_tokens) AS token_count
|
||||
FROM
|
||||
workflow_runs
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND triggered_from = :triggered_from"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
response_data = self._workflow_run_repo.get_daily_token_cost_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at >= :start"
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query += " AND created_at < :end"
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
|
||||
sql_query += " GROUP BY date ORDER BY date"
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{
|
||||
"date": str(i.date),
|
||||
"token_count": i.token_count,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/statistics/average-app-interactions")
|
||||
class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
@console_ns.doc("get_workflow_average_app_interaction_statistic")
|
||||
@console_ns.doc(description="Get workflow average app interaction statistics")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowStatisticQuery.__name__])
|
||||
@console_ns.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@api.doc("get_workflow_average_app_interaction_statistic")
|
||||
@api.doc(description="Get workflow average app interaction statistics")
|
||||
@api.doc(params={"app_id": "Application ID"})
|
||||
@api.doc(params={"start": "Start date and time (YYYY-MM-DD HH:MM)", "end": "End date and time (YYYY-MM-DD HH:MM)"})
|
||||
@api.response(200, "Average app interaction statistics retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def get(self, app_model):
|
||||
account, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
|
||||
args = WorkflowStatisticQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert account.timezone is not None
|
||||
sql_query = """SELECT
|
||||
AVG(sub.interactions) AS interactions,
|
||||
sub.date
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
DATE(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
c.created_by,
|
||||
COUNT(c.id) AS interactions
|
||||
FROM
|
||||
workflow_runs c
|
||||
WHERE
|
||||
c.app_id = :app_id
|
||||
AND c.triggered_from = :triggered_from
|
||||
{{start}}
|
||||
{{end}}
|
||||
GROUP BY
|
||||
date, c.created_by
|
||||
) sub
|
||||
GROUP BY
|
||||
sub.date"""
|
||||
arg_dict = {
|
||||
"tz": account.timezone,
|
||||
"app_id": app_model.id,
|
||||
"triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value,
|
||||
}
|
||||
|
||||
try:
|
||||
start_date, end_date = parse_time_range(args.start, args.end, account.timezone)
|
||||
except ValueError as e:
|
||||
abort(400, description=str(e))
|
||||
timezone = pytz.timezone(account.timezone)
|
||||
utc_timezone = pytz.utc
|
||||
|
||||
response_data = self._workflow_run_repo.get_average_app_interaction_statistics(
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
timezone=account.timezone,
|
||||
)
|
||||
if args["start"]:
|
||||
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
|
||||
start_datetime = start_datetime.replace(second=0)
|
||||
|
||||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start")
|
||||
arg_dict["start"] = start_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{start}}", "")
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
end_datetime = end_datetime.replace(second=0)
|
||||
|
||||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
sql_query = sql_query.replace("{{end}}", " AND c.created_at < :end")
|
||||
arg_dict["end"] = end_datetime_utc
|
||||
else:
|
||||
sql_query = sql_query.replace("{{end}}", "")
|
||||
|
||||
response_data = []
|
||||
|
||||
with db.engine.begin() as conn:
|
||||
rs = conn.execute(sa.text(sql_query), arg_dict)
|
||||
for i in rs:
|
||||
response_data.append(
|
||||
{"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))}
|
||||
)
|
||||
|
||||
return jsonify({"data": response_data})
|
||||
|
||||
@ -1,157 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_trigger_fields import trigger_fields, triggers_list_fields, webhook_trigger_fields
|
||||
from libs.login import current_user, login_required
|
||||
from models.enums import AppTriggerStatus
|
||||
from models.model import Account, App, AppMode
|
||||
from models.trigger import AppTrigger, WorkflowWebhookTrigger
|
||||
|
||||
from .. import console_ns
|
||||
from ..app.wraps import get_app_model
|
||||
from ..wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class ParserEnable(BaseModel):
|
||||
trigger_id: str
|
||||
enable_trigger: bool
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
console_ns.schema_model(
|
||||
ParserEnable.__name__, ParserEnable.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/triggers/webhook")
|
||||
class WebhookTriggerApi(Resource):
|
||||
"""Webhook Trigger API"""
|
||||
|
||||
@console_ns.expect(console_ns.models[Parser.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(webhook_trigger_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get webhook trigger for a node"""
|
||||
args = Parser.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
raise NotFound("Webhook trigger not found for this node")
|
||||
|
||||
return webhook_trigger
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/triggers")
|
||||
class AppTriggersApi(Resource):
|
||||
"""App Triggers list API"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(triggers_list_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Get app triggers list"""
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
select(AppTrigger)
|
||||
.where(
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
.order_by(AppTrigger.created_at.desc(), AppTrigger.id.desc())
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Add computed icon field for each trigger
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
for trigger in triggers:
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return {"data": triggers}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/trigger-enable")
|
||||
class AppTriggerEnableApi(Resource):
|
||||
@console_ns.expect(console_ns.models[ParserEnable.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.WORKFLOW)
|
||||
@marshal_with(trigger_fields)
|
||||
def post(self, app_model: App):
|
||||
"""Update app trigger (enable/disable)"""
|
||||
args = ParserEnable.model_validate(console_ns.payload)
|
||||
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with Session(db.engine) as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
AppTrigger.id == trigger_id,
|
||||
AppTrigger.tenant_id == current_user.current_tenant_id,
|
||||
AppTrigger.app_id == app_model.id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not trigger:
|
||||
raise NotFound("Trigger not found")
|
||||
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
trigger.icon = url_prefix + trigger.provider_name + "/icon" # type: ignore
|
||||
else:
|
||||
trigger.icon = "" # type: ignore
|
||||
|
||||
return trigger
|
||||
@ -4,29 +4,28 @@ from typing import ParamSpec, TypeVar, Union
|
||||
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
assert isinstance(current_user, Account)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@ -2,31 +2,35 @@ from flask import request
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from models import AccountStatus
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
active_check_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="args", help="Email address")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="args", help="Activation token")
|
||||
active_check_parser = reqparse.RequestParser()
|
||||
active_check_parser.add_argument(
|
||||
"workspace_id", type=str, required=False, nullable=True, location="args", help="Workspace ID"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"email", type=email, required=False, nullable=True, location="args", help="Email address"
|
||||
)
|
||||
active_check_parser.add_argument(
|
||||
"token", type=str, required=True, nullable=False, location="args", help="Activation token"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/activate/check")
|
||||
class ActivateCheckApi(Resource):
|
||||
@console_ns.doc("check_activation_token")
|
||||
@console_ns.doc(description="Check if activation token is valid")
|
||||
@console_ns.expect(active_check_parser)
|
||||
@console_ns.response(
|
||||
@api.doc("check_activation_token")
|
||||
@api.doc(description="Check if activation token is valid")
|
||||
@api.expect(active_check_parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.model(
|
||||
api.model(
|
||||
"ActivationCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether token is valid"),
|
||||
@ -56,26 +60,26 @@ class ActivateCheckApi(Resource):
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
active_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
.add_argument("interface_language", type=supported_language, required=True, nullable=False, location="json")
|
||||
.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
active_parser = reqparse.RequestParser()
|
||||
active_parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
active_parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
active_parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
active_parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
)
|
||||
active_parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/activate")
|
||||
class ActivateApi(Resource):
|
||||
@console_ns.doc("activate_account")
|
||||
@console_ns.doc(description="Activate account with invitation token")
|
||||
@console_ns.expect(active_parser)
|
||||
@console_ns.response(
|
||||
@api.doc("activate_account")
|
||||
@api.doc(description="Activate account with invitation token")
|
||||
@api.expect(active_parser)
|
||||
@api.response(
|
||||
200,
|
||||
"Account activated successfully",
|
||||
console_ns.model(
|
||||
api.model(
|
||||
"ActivationResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
@ -83,7 +87,7 @@ class ActivateApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@console_ns.response(400, "Already activated or invalid token")
|
||||
@api.response(400, "Already activated or invalid token")
|
||||
def post(self):
|
||||
args = active_parser.parse_args()
|
||||
|
||||
@ -99,7 +103,7 @@ class ActivateApi(Resource):
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
@ -15,8 +16,7 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
@ -39,20 +39,18 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("category", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, args)
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
@ -63,11 +61,11 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@is_admin_or_owner_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -2,11 +2,12 @@ import logging
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import is_admin_or_owner_required
|
||||
from controllers.console import api, console_ns
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
@ -29,22 +30,23 @@ def get_oauth_providers():
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>")
|
||||
class OAuthDataSource(Resource):
|
||||
@console_ns.doc("oauth_data_source")
|
||||
@console_ns.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@console_ns.response(
|
||||
@api.doc("oauth_data_source")
|
||||
@api.doc(description="Get OAuth authorization URL for data source provider")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)"})
|
||||
@api.response(
|
||||
200,
|
||||
"Authorization URL or internal setup success",
|
||||
console_ns.model(
|
||||
api.model(
|
||||
"OAuthDataSourceResponse",
|
||||
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
|
||||
),
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@is_admin_or_owner_required
|
||||
@api.response(400, "Invalid provider")
|
||||
@api.response(403, "Admin privileges required")
|
||||
def get(self, provider: str):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
|
||||
@ -63,17 +65,17 @@ class OAuthDataSource(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/callback/<string:provider>")
|
||||
class OAuthDataSourceCallback(Resource):
|
||||
@console_ns.doc("oauth_data_source_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback from data source provider")
|
||||
@console_ns.doc(
|
||||
@api.doc("oauth_data_source_callback")
|
||||
@api.doc(description="Handle OAuth callback from data source provider")
|
||||
@api.doc(
|
||||
params={
|
||||
"provider": "Data source provider name (notion)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"error": "Error message from OAuth provider",
|
||||
}
|
||||
)
|
||||
@console_ns.response(302, "Redirect to console with result")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@api.response(302, "Redirect to console with result")
|
||||
@api.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -94,17 +96,17 @@ class OAuthDataSourceCallback(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/binding/<string:provider>")
|
||||
class OAuthDataSourceBinding(Resource):
|
||||
@console_ns.doc("oauth_data_source_binding")
|
||||
@console_ns.doc(description="Bind OAuth data source with authorization code")
|
||||
@console_ns.doc(
|
||||
@api.doc("oauth_data_source_binding")
|
||||
@api.doc(description="Bind OAuth data source with authorization code")
|
||||
@api.doc(
|
||||
params={"provider": "Data source provider name (notion)", "code": "Authorization code from OAuth provider"}
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
200,
|
||||
"Data source binding success",
|
||||
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
api.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or code")
|
||||
@api.response(400, "Invalid provider or code")
|
||||
def get(self, provider: str):
|
||||
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -128,15 +130,15 @@ class OAuthDataSourceBinding(Resource):
|
||||
|
||||
@console_ns.route("/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
|
||||
class OAuthDataSourceSync(Resource):
|
||||
@console_ns.doc("oauth_data_source_sync")
|
||||
@console_ns.doc(description="Sync data from OAuth data source")
|
||||
@console_ns.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@console_ns.response(
|
||||
@api.doc("oauth_data_source_sync")
|
||||
@api.doc(description="Sync data from OAuth data source")
|
||||
@api.doc(params={"provider": "Data source provider name (notion)", "binding_id": "Data source binding ID"})
|
||||
@api.response(
|
||||
200,
|
||||
"Data source sync success",
|
||||
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
api.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@console_ns.response(400, "Invalid provider or sync failed")
|
||||
@api.response(400, "Invalid provider or sync failed")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -19,7 +19,7 @@ from controllers.console.wraps import email_password_login_enabled, email_regist
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
@ -31,11 +31,9 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -61,12 +59,10 @@ class EmailRegisterCheckApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -104,12 +100,10 @@ class EmailRegisterResetApi(Resource):
|
||||
@email_password_login_enabled
|
||||
@email_register_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
|
||||
@ -6,7 +6,7 @@ from flask_restx import Resource, fields, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
@ -20,17 +20,17 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@console_ns.route("/forgot-password")
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@console_ns.doc("send_forgot_password_email")
|
||||
@console_ns.doc(description="Send password reset email")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("send_forgot_password_email")
|
||||
@api.doc(description="Send password reset email")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordEmailRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
@ -38,10 +38,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
200,
|
||||
"Email sent successfully",
|
||||
console_ns.model(
|
||||
api.model(
|
||||
"ForgotPasswordEmailResponse",
|
||||
{
|
||||
"result": fields.String(description="Operation result"),
|
||||
@ -50,15 +50,13 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@console_ns.response(400, "Invalid email or rate limit exceeded")
|
||||
@api.response(400, "Invalid email or rate limit exceeded")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -85,10 +83,10 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
|
||||
@console_ns.route("/forgot-password/validity")
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@console_ns.doc("check_forgot_password_code")
|
||||
@console_ns.doc(description="Verify password reset code")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("check_forgot_password_code")
|
||||
@api.doc(description="Verify password reset code")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordCheckRequest",
|
||||
{
|
||||
"email": fields.String(required=True, description="Email address"),
|
||||
@ -97,10 +95,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
200,
|
||||
"Code verified successfully",
|
||||
console_ns.model(
|
||||
api.model(
|
||||
"ForgotPasswordCheckResponse",
|
||||
{
|
||||
"is_valid": fields.Boolean(description="Whether code is valid"),
|
||||
@ -109,16 +107,14 @@ class ForgotPasswordCheckApi(Resource):
|
||||
},
|
||||
),
|
||||
)
|
||||
@console_ns.response(400, "Invalid code or token")
|
||||
@api.response(400, "Invalid code or token")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
@ -152,10 +148,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
@console_ns.route("/forgot-password/resets")
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@console_ns.doc("reset_password")
|
||||
@console_ns.doc(description="Reset password with verification token")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("reset_password")
|
||||
@api.doc(description="Reset password with verification token")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ForgotPasswordResetRequest",
|
||||
{
|
||||
"token": fields.String(required=True, description="Verification token"),
|
||||
@ -164,21 +160,19 @@ class ForgotPasswordResetApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(
|
||||
@api.response(
|
||||
200,
|
||||
"Password reset successfully",
|
||||
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
api.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
|
||||
)
|
||||
@console_ns.response(400, "Invalid token or password mismatch")
|
||||
@api.response(400, "Invalid token or password mismatch")
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate passwords match
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from typing import cast
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants.languages import get_valid_language
|
||||
from constants.languages import languages
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
@ -24,16 +26,7 @@ from controllers.console.error import (
|
||||
from controllers.console.wraps import email_password_login_enabled, setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_refresh_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -49,13 +42,11 @@ class LoginApi(Resource):
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("password", type=str, required=True, location="json")
|
||||
.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=str, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args["email"]):
|
||||
@ -98,36 +89,19 @@ class LoginApi(Resource):
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
|
||||
return response
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Clear cookies on logout
|
||||
clear_access_token_from_cookie(response)
|
||||
clear_refresh_token_from_cookie(response)
|
||||
clear_csrf_token_from_cookie(response)
|
||||
|
||||
return response
|
||||
return {"result": "success"}
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
@ -135,11 +109,9 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
@ -165,11 +137,9 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=email, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
@ -200,17 +170,13 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("email", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=True, location="json")
|
||||
.add_argument("token", type=str, required=True, location="json")
|
||||
.add_argument("language", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
language = args["language"]
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
@ -244,9 +210,7 @@ class EmailCodeLoginApi(Resource):
|
||||
if account is None:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
@ -256,36 +220,18 @@ class EmailCodeLoginApi(Resource):
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
# Set HTTP-only secure cookies for tokens
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
return response
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = extract_refresh_token(request)
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Update cookies with new tokens
|
||||
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||
return response
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
except Exception as e:
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
@ -14,19 +14,15 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models import Account, AccountStatus
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import console_ns
|
||||
from .. import api, console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -56,13 +52,11 @@ def get_oauth_providers():
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@console_ns.doc("oauth_login")
|
||||
@console_ns.doc(description="Initiate OAuth login process")
|
||||
@console_ns.doc(
|
||||
params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"}
|
||||
)
|
||||
@console_ns.response(302, "Redirect to OAuth authorization URL")
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
@api.doc("oauth_login")
|
||||
@api.doc(description="Initiate OAuth login process")
|
||||
@api.doc(params={"provider": "OAuth provider name (github/google)", "invite_token": "Optional invitation token"})
|
||||
@api.response(302, "Redirect to OAuth authorization URL")
|
||||
@api.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
@ -77,17 +71,17 @@ class OAuthLogin(Resource):
|
||||
|
||||
@console_ns.route("/oauth/authorize/<provider>")
|
||||
class OAuthCallback(Resource):
|
||||
@console_ns.doc("oauth_callback")
|
||||
@console_ns.doc(description="Handle OAuth callback and complete login process")
|
||||
@console_ns.doc(
|
||||
@api.doc("oauth_callback")
|
||||
@api.doc(description="Handle OAuth callback and complete login process")
|
||||
@api.doc(
|
||||
params={
|
||||
"provider": "OAuth provider name (github/google)",
|
||||
"code": "Authorization code from OAuth provider",
|
||||
"state": "Optional state parameter (used for invite token)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(302, "Redirect to console with access token")
|
||||
@console_ns.response(400, "OAuth process failed")
|
||||
@api.response(302, "Redirect to console with access token")
|
||||
@api.response(400, "OAuth process failed")
|
||||
def get(self, provider: str):
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
@ -136,11 +130,11 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")
|
||||
|
||||
# Check account status
|
||||
if account.status == AccountStatus.BANNED:
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING:
|
||||
account.status = AccountStatus.ACTIVE
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -159,12 +153,9 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Account
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
@ -23,7 +24,8 @@ T = TypeVar("T")
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser().add_argument("client_id", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
@ -89,7 +91,8 @@ class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser().add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
@ -113,8 +116,7 @@ class OAuthServerUserAuthorizeApi(Resource):
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
@ -130,14 +132,12 @@ class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("grant_type", type=str, required=True, location="json")
|
||||
.add_argument("code", type=str, required=False, location="json")
|
||||
.add_argument("client_secret", type=str, required=False, location="json")
|
||||
.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
import base64
|
||||
|
||||
from flask_restx import Resource, fields, reqparse
|
||||
from werkzeug.exceptions import BadRequest
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -17,21 +14,17 @@ class Subscription(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"plan",
|
||||
type=str,
|
||||
required=True,
|
||||
location="args",
|
||||
choices=[CloudPlan.PROFESSIONAL, CloudPlan.TEAM],
|
||||
)
|
||||
.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args["plan"], args["interval"], current_user.email, current_tenant_id)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/billing/invoices")
|
||||
@ -41,40 +34,7 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
assert isinstance(current_user, Account)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_invoices(current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@console_ns.route("/billing/partners/<string:partner_key>/tenants")
|
||||
class PartnerTenants(Resource):
|
||||
@console_ns.doc("sync_partner_tenants_bindings")
|
||||
@console_ns.doc(description="Sync partner tenants bindings")
|
||||
@console_ns.doc(params={"partner_key": "Partner key"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
"SyncPartnerTenantsBindingsRequest",
|
||||
{"click_id": fields.String(required=True, description="Click Id from partner referral link")},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Tenants synced to partner successfully")
|
||||
@console_ns.response(400, "Invalid partner information")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def put(self, partner_key: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("click_id", required=True, type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
click_id = args["click_id"]
|
||||
decoded_partner_key = base64.b64decode(partner_key).decode("utf-8")
|
||||
except Exception:
|
||||
raise BadRequest("Invalid partner_key")
|
||||
|
||||
if not click_id or not decoded_partner_key or not current_user.id:
|
||||
raise BadRequest("Invalid partner information")
|
||||
|
||||
return BillingService.sync_partner_tenants_bindings(current_user.id, decoded_partner_key, click_id)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import console_ns
|
||||
@ -16,16 +17,17 @@ class ComplianceApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("doc_name", type=str, required=True, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
return BillingService.get_compliance_download_link(
|
||||
doc_name=args.doc_name,
|
||||
account_id=current_user.id,
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
@ -3,6 +3,7 @@ from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -14,12 +15,12 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
@ -36,12 +37,10 @@ class DataSourceApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
DataSourceOauthBinding.tenant_id == current_tenant_id,
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
)
|
||||
).all()
|
||||
@ -121,15 +120,13 @@ class DataSourceNotionListApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
@ -149,7 +146,7 @@ class DataSourceNotionListApi(Resource):
|
||||
documents = session.scalars(
|
||||
select(Document).filter_by(
|
||||
dataset_id=dataset_id,
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
data_source_type="notion_import",
|
||||
enabled=True,
|
||||
)
|
||||
@ -164,7 +161,7 @@ class DataSourceNotionListApi(Resource):
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -213,14 +210,12 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
@ -234,7 +229,7 @@ class DataSourceNotionApi(Resource):
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
text_docs = extractor.extract()
|
||||
@ -244,14 +239,12 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
# validate args
|
||||
@ -263,22 +256,20 @@ class DataSourceNotionApi(Resource):
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
document_model=args["doc_form"],
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
indexing_runner = IndexingRunner()
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_tenant_id,
|
||||
current_user.current_tenant_id,
|
||||
extract_settings,
|
||||
args["process_rule"],
|
||||
args["doc_form"],
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -6,12 +6,13 @@ from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
@ -43,19 +44,18 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
document_metadata_fields,
|
||||
document_status_fields,
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
@ -63,39 +63,8 @@ from services.entities.knowledge_entities.knowledge_entities import KnowledgeCon
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_model = _get_or_create_model("Dataset", dataset_fields)
|
||||
|
||||
document_metadata_model = _get_or_create_model("DocumentMetadata", document_metadata_fields)
|
||||
|
||||
document_fields_copy = document_fields.copy()
|
||||
document_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_model = _get_or_create_model("Document", document_fields_copy)
|
||||
|
||||
document_with_segments_fields_copy = document_with_segments_fields.copy()
|
||||
document_with_segments_fields_copy["doc_metadata"] = fields.List(
|
||||
fields.Nested(document_metadata_model), attribute="doc_metadata_details"
|
||||
)
|
||||
document_with_segments_model = _get_or_create_model("DocumentWithSegments", document_with_segments_fields_copy)
|
||||
|
||||
dataset_and_document_fields_copy = dataset_and_document_fields.copy()
|
||||
dataset_and_document_fields_copy["dataset"] = fields.Nested(dataset_model)
|
||||
dataset_and_document_fields_copy["documents"] = fields.List(fields.Nested(document_model))
|
||||
dataset_and_document_model = _get_or_create_model("DatasetAndDocument", dataset_and_document_fields_copy)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -110,13 +79,12 @@ class DocumentResource(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
if document.tenant_id != current_tenant_id:
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -136,15 +104,14 @@ class DocumentResource(Resource):
|
||||
|
||||
@console_ns.route("/datasets/process-rule")
|
||||
class GetProcessRuleApi(Resource):
|
||||
@console_ns.doc("get_process_rule")
|
||||
@console_ns.doc(description="Get dataset document processing rules")
|
||||
@console_ns.doc(params={"document_id": "Document ID (optional)"})
|
||||
@console_ns.response(200, "Process rules retrieved successfully")
|
||||
@api.doc("get_process_rule")
|
||||
@api.doc(description="Get dataset document processing rules")
|
||||
@api.doc(params={"document_id": "Document ID (optional)"})
|
||||
@api.response(200, "Process rules retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -184,9 +151,9 @@ class GetProcessRuleApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
class DatasetDocumentListApi(Resource):
|
||||
@console_ns.doc("get_dataset_documents")
|
||||
@console_ns.doc(description="Get documents in a dataset")
|
||||
@console_ns.doc(
|
||||
@api.doc("get_dataset_documents")
|
||||
@api.doc(description="Get documents in a dataset")
|
||||
@api.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"page": "Page number (default: 1)",
|
||||
@ -194,20 +161,18 @@ class DatasetDocumentListApi(Resource):
|
||||
"keyword": "Search keyword",
|
||||
"sort": "Sort order (default: -created_at)",
|
||||
"fetch": "Fetch full details (default: false)",
|
||||
"status": "Filter documents by display status",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Documents retrieved successfully")
|
||||
@api.response(200, "Documents retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
def get(self, dataset_id):
|
||||
dataset_id = str(dataset_id)
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
sort = request.args.get("sort", default="-created_at", type=str)
|
||||
status = request.args.get("status", default=None, type=str)
|
||||
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
|
||||
try:
|
||||
fetch_val = request.args.get("fetch", default="false")
|
||||
@ -234,10 +199,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_tenant_id)
|
||||
|
||||
if status:
|
||||
query = DocumentService.apply_display_status_filter(query, status)
|
||||
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
|
||||
|
||||
if search:
|
||||
search = f"%{search}%"
|
||||
@ -307,11 +269,10 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -328,23 +289,23 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=False, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=False, location="json")
|
||||
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
|
||||
parser.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
|
||||
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@ -388,10 +349,10 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
class DatasetInitApi(Resource):
|
||||
@console_ns.doc("init_dataset")
|
||||
@console_ns.doc(description="Initialize dataset with documents")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("init_dataset")
|
||||
@api.doc(description="Initialize dataset with documents")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"DatasetInitRequest",
|
||||
{
|
||||
"upload_file_id": fields.String(required=True, description="Upload file ID"),
|
||||
@ -401,48 +362,47 @@ class DatasetInitApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "Dataset initialized successfully", dataset_and_document_model)
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@api.response(201, "Dataset initialized successfully", dataset_and_document_fields)
|
||||
@api.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_model)
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
required=True,
|
||||
nullable=False,
|
||||
location="json",
|
||||
)
|
||||
parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
|
||||
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=args["embedding_model_provider"],
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=args["embedding_model"],
|
||||
@ -459,9 +419,9 @@ class DatasetInitApi(Resource):
|
||||
|
||||
try:
|
||||
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -477,17 +437,16 @@ class DatasetInitApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-estimate")
|
||||
class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@console_ns.doc("estimate_document_indexing")
|
||||
@console_ns.doc(description="Estimate document indexing cost")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Document already finished")
|
||||
@api.doc("estimate_document_indexing")
|
||||
@api.doc(description="Estimate document indexing cost")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing estimate calculated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Document already finished")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -516,14 +475,14 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
|
||||
try:
|
||||
estimate_response = indexing_runner.indexing_estimate(
|
||||
current_tenant_id,
|
||||
current_user.current_tenant_id,
|
||||
[extract_setting],
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -552,7 +511,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, batch):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
batch = str(batch)
|
||||
documents = self.get_batch_documents(dataset_id, batch)
|
||||
@ -572,7 +530,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -580,7 +538,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
@ -588,16 +546,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION,
|
||||
notion_info=NotionInfo.model_validate(
|
||||
{
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_tenant_id,
|
||||
}
|
||||
),
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
},
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -605,17 +561,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE,
|
||||
website_info=WebsiteInfo.model_validate(
|
||||
{
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
}
|
||||
),
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
"url": data_source_info["url"],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info["mode"],
|
||||
"only_main_content": data_source_info["only_main_content"],
|
||||
},
|
||||
document_model=document.doc_form,
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
@ -625,7 +579,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
indexing_runner = IndexingRunner()
|
||||
try:
|
||||
response = indexing_runner.indexing_estimate(
|
||||
current_tenant_id,
|
||||
current_user.current_tenant_id,
|
||||
extract_settings,
|
||||
data_process_rule_dict,
|
||||
document.doc_form,
|
||||
@ -692,11 +646,11 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/indexing-status")
|
||||
class DocumentIndexingStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_indexing_status")
|
||||
@console_ns.doc(description="Get document indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@api.doc("get_document_indexing_status")
|
||||
@api.doc(description="Get document indexing status")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.response(200, "Indexing status retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -742,17 +696,17 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
class DocumentApi(DocumentResource):
|
||||
METADATA_CHOICES = {"all", "only", "without"}
|
||||
|
||||
@console_ns.doc("get_document")
|
||||
@console_ns.doc(description="Get document details")
|
||||
@console_ns.doc(
|
||||
@api.doc("get_document")
|
||||
@api.doc(description="Get document details")
|
||||
@api.doc(
|
||||
params={
|
||||
"dataset_id": "Dataset ID",
|
||||
"document_id": "Document ID",
|
||||
"metadata": "Metadata inclusion (all/only/without)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Document retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@api.response(200, "Document retrieved successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -782,7 +736,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -815,7 +769,7 @@ class DocumentApi(DocumentResource):
|
||||
"name": document.name,
|
||||
"created_from": document.created_from,
|
||||
"created_by": document.created_by,
|
||||
"created_at": int(document.created_at.timestamp()),
|
||||
"created_at": document.created_at.timestamp(),
|
||||
"tokens": document.tokens,
|
||||
"indexing_status": document.indexing_status,
|
||||
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
|
||||
@ -863,20 +817,19 @@ class DocumentApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@console_ns.doc("update_document_processing")
|
||||
@console_ns.doc(description="Update document processing status (pause/resume)")
|
||||
@console_ns.doc(
|
||||
@api.doc("update_document_processing")
|
||||
@api.doc(description="Update document processing status (pause/resume)")
|
||||
@api.doc(
|
||||
params={"dataset_id": "Dataset ID", "document_id": "Document ID", "action": "Action to perform (pause/resume)"}
|
||||
)
|
||||
@console_ns.response(200, "Processing status updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(400, "Invalid action")
|
||||
@api.response(200, "Processing status updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(400, "Invalid action")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -908,11 +861,11 @@ class DocumentProcessingApi(DocumentResource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
|
||||
class DocumentMetadataApi(DocumentResource):
|
||||
@console_ns.doc("update_document_metadata")
|
||||
@console_ns.doc(description="Update document metadata")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("update_document_metadata")
|
||||
@api.doc(description="Update document metadata")
|
||||
@api.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"UpdateDocumentMetadataRequest",
|
||||
{
|
||||
"doc_type": fields.String(description="Document type"),
|
||||
@ -920,14 +873,13 @@ class DocumentMetadataApi(DocumentResource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Document metadata updated successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@api.response(200, "Document metadata updated successfully")
|
||||
@api.response(404, "Document not found")
|
||||
@api.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
@ -975,7 +927,6 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
@ -1079,9 +1030,8 @@ class DocumentRetryApi(DocumentResource):
|
||||
def post(self, dataset_id):
|
||||
"""retry document."""
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"document_ids", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1123,14 +1073,14 @@ class DocumentRenameApi(DocumentResource):
|
||||
@marshal_with(document_fields)
|
||||
def post(self, dataset_id, document_id):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_operator_permission(current_user, dataset)
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
@ -1148,7 +1098,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -1157,7 +1106,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.tenant_id != current_tenant_id:
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
if document.data_source_type != "website_crawl":
|
||||
raise ValueError("Document is not a website document.")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
@ -26,7 +27,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
@ -42,8 +43,6 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -60,15 +59,13 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
.add_argument("enabled", type=str, default="all", location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
|
||||
parser.add_argument("enabled", type=str, default="all", location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -82,7 +79,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
select(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document_id),
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
DocumentSegment.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
)
|
||||
@ -118,8 +115,6 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -153,8 +148,6 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
@ -178,7 +171,7 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -211,8 +204,6 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -230,7 +221,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -246,12 +237,10 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.create_segment(args, document, dataset)
|
||||
@ -266,8 +255,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -285,7 +272,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -300,7 +287,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -313,18 +300,16 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
|
||||
parser.add_argument(
|
||||
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
SegmentService.segment_create_args_validate(args, document)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
|
||||
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@ -332,8 +317,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -350,7 +333,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -378,8 +361,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -391,9 +372,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"upload_file_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("upload_file_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
upload_file_id = args["upload_file_id"]
|
||||
|
||||
@ -416,7 +396,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
upload_file_id,
|
||||
dataset_id,
|
||||
document_id,
|
||||
current_tenant_id,
|
||||
current_user.current_tenant_id,
|
||||
current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -447,8 +427,6 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -463,7 +441,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -475,7 +453,7 @@ class ChildChunkAddApi(Resource):
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_model_instance(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=dataset.embedding_model,
|
||||
@ -491,9 +469,8 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
@ -506,8 +483,6 @@ class ChildChunkAddApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id, segment_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -524,17 +499,15 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("limit", type=int, default=20, location="args")
|
||||
parser.add_argument("keyword", type=str, default=None, location="args")
|
||||
parser.add_argument("page", type=int, default=1, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -557,8 +530,6 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -575,7 +546,7 @@ class ChildChunkAddApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -588,13 +559,12 @@ class ChildChunkAddApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"chunks", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
chunks_data = args["chunks"]
|
||||
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
|
||||
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
|
||||
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
@ -610,8 +580,6 @@ class ChildChunkUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -628,7 +596,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -639,7 +607,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -666,8 +634,6 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -684,7 +650,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not segment:
|
||||
@ -695,7 +661,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.tenant_id == current_user.current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
@ -711,9 +677,8 @@ class ChildChunkUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
# validate args
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
content = args["content"]
|
||||
|
||||
@ -1,76 +1,23 @@
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
def _build_dataset_detail_model():
|
||||
keyword_setting_model = _get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = _get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = _get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = _get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = _get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
tag_model = _get_or_create_model("Tag", tag_fields)
|
||||
doc_metadata_model = _get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
external_knowledge_info_model = _get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
external_retrieval_model = _get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
icon_info_model = _get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
return _get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
|
||||
try:
|
||||
dataset_detail_model = console_ns.models["DatasetDetail"]
|
||||
except KeyError:
|
||||
dataset_detail_model = _build_dataset_detail_model()
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 100:
|
||||
raise ValueError("Name must be between 1 to 100 characters.")
|
||||
@ -79,27 +26,26 @@ def _validate_name(name: str) -> str:
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api")
|
||||
class ExternalApiTemplateListApi(Resource):
|
||||
@console_ns.doc("get_external_api_templates")
|
||||
@console_ns.doc(description="Get external knowledge API templates")
|
||||
@console_ns.doc(
|
||||
@api.doc("get_external_api_templates")
|
||||
@api.doc(description="Get external knowledge API templates")
|
||||
@api.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"keyword": "Search keyword",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "External API templates retrieved successfully")
|
||||
@api.response(200, "External API templates retrieved successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
search = request.args.get("keyword", default=None, type=str)
|
||||
|
||||
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
|
||||
page, limit, current_tenant_id, search
|
||||
page, limit, current_user.current_tenant_id, search
|
||||
)
|
||||
response = {
|
||||
"data": [item.to_dict() for item in external_knowledge_apis],
|
||||
@ -114,23 +60,20 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -142,7 +85,7 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
try:
|
||||
external_knowledge_api = ExternalDatasetService.create_external_knowledge_api(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, args=args
|
||||
tenant_id=current_user.current_tenant_id, user_id=current_user.id, args=args
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@ -152,11 +95,11 @@ class ExternalApiTemplateListApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
|
||||
class ExternalApiTemplateApi(Resource):
|
||||
@console_ns.doc("get_external_api_template")
|
||||
@console_ns.doc(description="Get external knowledge API template details")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "External API template retrieved successfully")
|
||||
@console_ns.response(404, "Template not found")
|
||||
@api.doc("get_external_api_template")
|
||||
@api.doc(description="Get external knowledge API template details")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "External API template retrieved successfully")
|
||||
@api.response(404, "Template not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -172,31 +115,28 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"settings",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=False,
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
ExternalDatasetService.validate_api_list(args["settings"])
|
||||
|
||||
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
args=args,
|
||||
@ -208,22 +148,22 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, external_knowledge_api_id):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.is_editor or current_user.is_dataset_operator):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
class ExternalApiUseCheckApi(Resource):
|
||||
@console_ns.doc("check_external_api_usage")
|
||||
@console_ns.doc(description="Check if external knowledge API is being used")
|
||||
@console_ns.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@console_ns.response(200, "Usage check completed successfully")
|
||||
@api.doc("check_external_api_usage")
|
||||
@api.doc(description="Check if external knowledge API is being used")
|
||||
@api.doc(params={"external_knowledge_api_id": "External knowledge API ID"})
|
||||
@api.response(200, "Usage check completed successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -238,10 +178,10 @@ class ExternalApiUseCheckApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/external")
|
||||
class ExternalDatasetCreateApi(Resource):
|
||||
@console_ns.doc("create_external_dataset")
|
||||
@console_ns.doc(description="Create external knowledge dataset")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("create_external_dataset")
|
||||
@api.doc(description="Create external knowledge dataset")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"CreateExternalDatasetRequest",
|
||||
{
|
||||
"external_knowledge_api_id": fields.String(required=True, description="External knowledge API ID"),
|
||||
@ -251,30 +191,29 @@ class ExternalDatasetCreateApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(201, "External dataset created successfully", dataset_detail_model)
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@api.response(201, "External dataset created successfully", dataset_detail_fields)
|
||||
@api.response(400, "Invalid parameters")
|
||||
@api.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("external_knowledge_api_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("external_knowledge_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="name is required. Name must be between 1 to 100 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument("description", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -284,7 +223,7 @@ class ExternalDatasetCreateApi(Resource):
|
||||
|
||||
try:
|
||||
dataset = ExternalDatasetService.create_external_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
args=args,
|
||||
)
|
||||
@ -296,11 +235,11 @@ class ExternalDatasetCreateApi(Resource):
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/external-hit-testing")
|
||||
class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@console_ns.doc("test_external_knowledge_retrieval")
|
||||
@console_ns.doc(description="Test external knowledge retrieval for dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("test_external_knowledge_retrieval")
|
||||
@api.doc(description="Test external knowledge retrieval for dataset")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ExternalHitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
@ -309,14 +248,13 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "External hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@api.response(200, "External hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -327,12 +265,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -341,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
response = HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
@ -354,10 +290,10 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@console_ns.route("/test/retrieval")
|
||||
class BedrockRetrievalApi(Resource):
|
||||
# this api is only for internal testing
|
||||
@console_ns.doc("bedrock_retrieval_test")
|
||||
@console_ns.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("bedrock_retrieval_test")
|
||||
@api.doc(description="Bedrock retrieval test (internal use only)")
|
||||
@api.expect(
|
||||
api.model(
|
||||
"BedrockRetrievalTestRequest",
|
||||
{
|
||||
"retrieval_setting": fields.Raw(required=True, description="Retrieval settings"),
|
||||
@ -366,19 +302,17 @@ class BedrockRetrievalApi(Resource):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Bedrock retrieval test completed")
|
||||
@api.response(200, "Bedrock retrieval test completed")
|
||||
def post(self):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
|
||||
parser.add_argument(
|
||||
"query",
|
||||
nullable=False,
|
||||
required=True,
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Call the knowledge retrieval service
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@ -12,11 +12,11 @@ from libs.login import login_required
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(
|
||||
console_ns.model(
|
||||
@api.doc("test_dataset_retrieval")
|
||||
@api.doc(description="Test dataset knowledge retrieval")
|
||||
@api.doc(params={"dataset_id": "Dataset ID"})
|
||||
@api.expect(
|
||||
api.model(
|
||||
"HitTestingRequest",
|
||||
{
|
||||
"query": fields.String(required=True, description="Query text for testing"),
|
||||
@ -26,9 +26,9 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
},
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@api.response(200, "Hit testing completed successfully")
|
||||
@api.response(404, "Dataset not found")
|
||||
@api.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
@ -19,7 +21,6 @@ from core.errors.error import (
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -30,7 +31,6 @@ logger = logging.getLogger(__name__)
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -48,22 +48,20 @@ class DatasetsHitTestingBase:
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
account=cast(Account, current_user),
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
@ -23,14 +24,11 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
metadata_args = MetadataArgs(**args)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -61,8 +59,8 @@ class DatasetMetadataApi(Resource):
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
parser = reqparse.RequestParser().add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
name = args["name"]
|
||||
|
||||
@ -81,7 +79,6 @@ class DatasetMetadataApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -111,7 +108,6 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -132,18 +128,16 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json"
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
metadata_args = MetadataOperationData(**args)
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -20,11 +24,11 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
tenant_id = current_tenant_id
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
@ -48,7 +52,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
@ -121,30 +125,27 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@console_ns.expect(parser_datasource)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
args = parser_datasource.parse_args()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
@ -159,39 +160,31 @@ class DatasourceAuth(Resource):
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
args = parser_datasource_delete.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
@ -199,30 +192,23 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
args = parser_datasource_update.parse_args()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
@ -238,10 +224,10 @@ class DatasourceAuthListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
@ -251,35 +237,29 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource_custom.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
@ -290,63 +270,52 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_default.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(parser_update_name)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_update_name.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
@ -12,21 +12,9 @@ from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class Parser(BaseModel):
|
||||
inputs: dict
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
console_ns.schema_model(Parser.__name__, Parser.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview")
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Parser.__name__], validate=True)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -38,10 +26,19 @@ class DataSourceContentPreviewApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
args = Parser.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
inputs = args.inputs
|
||||
datasource_type = args.datasource_type
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
@ -50,6 +47,6 @@ class DataSourceContentPreviewApi(Resource):
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.credential_id,
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
@ -66,31 +66,29 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@ -125,28 +123,26 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -12,7 +13,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
@ -26,7 +27,9 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
@ -35,7 +38,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
@ -55,12 +58,12 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_tenant_id,
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
@ -78,12 +81,10 @@ class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
|
||||
@ -23,7 +23,7 @@ from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
@ -33,18 +33,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
@ -208,11 +206,10 @@ class RagPipelineVariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
|
||||
@ -1,16 +1,20 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
@ -21,31 +25,29 @@ class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("pipeline_id", type=str, location="json")
|
||||
)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
@ -58,9 +60,9 @@ class RagPipelineImportApi(Resource):
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED:
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING:
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -70,21 +72,22 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
@ -95,9 +98,11 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
@ -111,10 +116,13 @@ class RagPipelineExportApi(Resource):
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user