mirror of
https://github.com/langgenius/dify.git
synced 2026-02-02 18:06:38 +08:00
Compare commits
83 Commits
2.0.0-beta
...
build/bool
| Author | SHA1 | Date | |
|---|---|---|---|
| bdfb31f332 | |||
| f762c969b3 | |||
| 6b87b0fa42 | |||
| 4386c9f4f1 | |||
| 62955f4d11 | |||
| e38ba8403d | |||
| a1f278bdc8 | |||
| 65c0f8d657 | |||
| c3e4d5a059 | |||
| 30775109fd | |||
| 76659843d7 | |||
| 42258bf451 | |||
| 399d510e1b | |||
| 8462c79179 | |||
| f5f11f5a9e | |||
| 96e8321462 | |||
| af6e5e8663 | |||
| d4f976270d | |||
| dc3cbd9a74 | |||
| 056564c100 | |||
| d5b99610fc | |||
| 04d2b0775d | |||
| 6614e97b53 | |||
| ca37e304cc | |||
| 437729db97 | |||
| fbd0effa04 | |||
| 14c62850e0 | |||
| c4824a1d4c | |||
| d642e9a98e | |||
| 72c6195c10 | |||
| 2bd096b454 | |||
| fa98505e09 | |||
| 65ae25e09b | |||
| 4943d8bc93 | |||
| 168d82f9b0 | |||
| 0e3ccb4dcc | |||
| f7f6210a4c | |||
| ed8e06d1bf | |||
| 6287e3cd9e | |||
| 94c33c6eed | |||
| 383695b820 | |||
| 3751702efc | |||
| 111c34e50f | |||
| a2246be4f5 | |||
| 80e08562be | |||
| 6d03a15e0f | |||
| bd04ddd544 | |||
| cb51360aa3 | |||
| 269d34304e | |||
| 590c14c977 | |||
| 5fbeb275b2 | |||
| ff9d051635 | |||
| 1e18c828d9 | |||
| 5b895be412 | |||
| ea2a1b203c | |||
| 7bfdfe73c2 | |||
| 171c5055f8 | |||
| e3e4369358 | |||
| efa28453be | |||
| 37de3b1b68 | |||
| 75d4e519fc | |||
| cf63926e16 | |||
| 14eb5b7930 | |||
| 544ebde054 | |||
| 6012ad57ac | |||
| 70f6d8b42a | |||
| 1d738f3fa6 | |||
| 912d68a148 | |||
| 8c8c250570 | |||
| b5782fff8f | |||
| f2dfb5363f | |||
| fc0dea5647 | |||
| fcf0ae4c85 | |||
| 9ac629bcf5 | |||
| 79bdb22f79 | |||
| b72d977871 | |||
| 657d3d1dae | |||
| 3ba23238e5 | |||
| e1a7b59160 | |||
| 18941beb34 | |||
| 6ff18d13e6 | |||
| 1094b3da23 | |||
| 77aa35ff15 |
@ -1,19 +0,0 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [],
|
||||
"deny": []
|
||||
},
|
||||
"env": {
|
||||
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
},
|
||||
"enabledMcpjsonServers": [
|
||||
"context7",
|
||||
"sequential-thinking",
|
||||
"github",
|
||||
"fetch",
|
||||
"playwright",
|
||||
"ide"
|
||||
],
|
||||
"enableAllProjectMcpServers": true
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
corepack enable
|
||||
npm add -g pnpm@10.15.0
|
||||
cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
|
||||
34
.github/actions/setup-uv/action.yml
vendored
Normal file
34
.github/actions/setup-uv/action.yml
vendored
Normal file
@ -0,0 +1,34 @@
|
||||
name: Setup UV and Python
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version to use and the UV installed with
|
||||
required: true
|
||||
default: '3.12'
|
||||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.8.9'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
default: ''
|
||||
enable-cache:
|
||||
required: true
|
||||
default: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Set up Python ${{ inputs.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
version: ${{ inputs.uv-version }}
|
||||
python-version: ${{ inputs.python-version }}
|
||||
enable-cache: ${{ inputs.enable-cache }}
|
||||
cache-dependency-glob: ${{ inputs.uv-lockfile }}
|
||||
28
.github/workflows/api-tests.yml
vendored
28
.github/workflows/api-tests.yml
vendored
@ -1,7 +1,13 @@
|
||||
name: Run Pytest
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/**
|
||||
- docker/**
|
||||
- .github/workflows/api-tests.yml
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -27,11 +33,10 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: ./.github/actions/setup-uv
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
uv-lockfile: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
@ -42,7 +47,11 @@ jobs:
|
||||
- name: Run Unit tests
|
||||
run: |
|
||||
uv run --project api bash dev/pytest/pytest_unit_tests.sh
|
||||
|
||||
- name: Run ty check
|
||||
run: |
|
||||
cd api
|
||||
uv add --dev ty
|
||||
uv run ty check || true
|
||||
- name: Run pyrefly check
|
||||
run: |
|
||||
cd api
|
||||
@ -62,6 +71,15 @@ jobs:
|
||||
- name: Run dify config tests
|
||||
run: uv run --project api dev/pytest/pytest_config_tests.py
|
||||
|
||||
- name: MyPy Cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: api/.mypy_cache
|
||||
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
|
||||
|
||||
- name: Run MyPy Checks
|
||||
run: dev/mypy-check
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
|
||||
9
.github/workflows/autofix.yml
vendored
9
.github/workflows/autofix.yml
vendored
@ -1,7 +1,9 @@
|
||||
name: autofix.ci
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
@ -13,9 +15,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
# Use uv to ensure we have the same ruff version in CI and locally.
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
|
||||
- run: |
|
||||
cd api
|
||||
uv sync --dev
|
||||
@ -26,7 +26,6 @@ jobs:
|
||||
- name: ast-grep
|
||||
run: |
|
||||
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
|
||||
- name: mdformat
|
||||
run: |
|
||||
uvx mdformat .
|
||||
|
||||
2
.github/workflows/build-push.yml
vendored
2
.github/workflows/build-push.yml
vendored
@ -8,8 +8,6 @@ on:
|
||||
- "deploy/enterprise"
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "deploy/rag-dev"
|
||||
- "feat/rag-2"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
20
.github/workflows/db-migration-test.yml
vendored
20
.github/workflows/db-migration-test.yml
vendored
@ -1,7 +1,13 @@
|
||||
name: DB Migration Test
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- plugins/beta
|
||||
paths:
|
||||
- api/migrations/**
|
||||
- .github/workflows/db-migration-test.yml
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
@ -19,20 +25,12 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: ./.github/actions/setup-uv
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
uv-lockfile: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api
|
||||
- name: Ensure Offline migration are supported
|
||||
run: |
|
||||
# upgrade
|
||||
uv run --directory api flask db upgrade 'base:head' --sql
|
||||
# downgrade
|
||||
uv run --directory api flask db downgrade 'head:base' --sql
|
||||
|
||||
- name: Prepare middleware env
|
||||
run: |
|
||||
|
||||
7
.github/workflows/deploy-dev.yml
vendored
7
.github/workflows/deploy-dev.yml
vendored
@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@ -12,13 +12,12 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
||||
78
.github/workflows/main-ci.yml
vendored
78
.github/workflows/main-ci.yml
vendored
@ -1,78 +0,0 @@
|
||||
name: Main CI Pipeline
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
push:
|
||||
branches: ["main"]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# Check which paths were changed to determine which tests to run
|
||||
check-changes:
|
||||
name: Check Changed Files
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
web-changed: ${{ steps.changes.outputs.web }}
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
api:
|
||||
- 'api/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
web:
|
||||
- 'web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- 'api/uv.lock'
|
||||
- 'api/pyproject.toml'
|
||||
migration:
|
||||
- 'api/migrations/**'
|
||||
- '.github/workflows/db-migration-test.yml'
|
||||
|
||||
# Run tests in parallel
|
||||
api-tests:
|
||||
name: API Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.api-changed == 'true'
|
||||
uses: ./.github/workflows/api-tests.yml
|
||||
|
||||
web-tests:
|
||||
name: Web Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.web-changed == 'true'
|
||||
uses: ./.github/workflows/web-tests.yml
|
||||
|
||||
style-check:
|
||||
name: Style Check
|
||||
uses: ./.github/workflows/style.yml
|
||||
|
||||
vdb-tests:
|
||||
name: VDB Tests
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.vdb-changed == 'true'
|
||||
uses: ./.github/workflows/vdb-tests.yml
|
||||
|
||||
db-migration-test:
|
||||
name: DB Migration Test
|
||||
needs: check-changes
|
||||
if: needs.check-changes.outputs.migration-changed == 'true'
|
||||
uses: ./.github/workflows/db-migration-test.yml
|
||||
29
.github/workflows/style.yml
vendored
29
.github/workflows/style.yml
vendored
@ -1,7 +1,9 @@
|
||||
name: Style check
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: style-${{ github.head_ref || github.run_id }}
|
||||
@ -12,6 +14,7 @@ permissions:
|
||||
statuses: write
|
||||
contents: read
|
||||
|
||||
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
@ -33,28 +36,30 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: ./.github/actions/setup-uv
|
||||
with:
|
||||
uv-lockfile: api/uv.lock
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Run Basedpyright Checks
|
||||
- name: Ruff check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: dev/basedpyright-check
|
||||
|
||||
- name: Run Mypy Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
||||
- name: Lint hints
|
||||
if: failure()
|
||||
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
@ -96,9 +101,7 @@ jobs:
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: |
|
||||
pnpm run lint
|
||||
pnpm run eslint
|
||||
run: pnpm run lint
|
||||
|
||||
docker-compose-template:
|
||||
name: Docker Compose Template
|
||||
|
||||
@ -67,22 +67,12 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
|
||||
|
||||
- name: Generate i18n type definitions
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run gen:i18n-types
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: Update i18n files and type definitions based on en-US changes
|
||||
title: 'chore: translate i18n files and update type definitions'
|
||||
body: |
|
||||
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
|
||||
|
||||
**Changes included:**
|
||||
- Updated translation files for all locales
|
||||
- Regenerated TypeScript type definitions for type safety
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
||||
|
||||
15
.github/workflows/vdb-tests.yml
vendored
15
.github/workflows/vdb-tests.yml
vendored
@ -1,7 +1,15 @@
|
||||
name: Run VDB Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- api/core/rag/datasource/**
|
||||
- docker/**
|
||||
- .github/workflows/vdb-tests.yml
|
||||
- api/uv.lock
|
||||
- api/pyproject.toml
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -31,11 +39,10 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
uses: ./.github/actions/setup-uv
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
uv-lockfile: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
11
.github/workflows/web-tests.yml
vendored
11
.github/workflows/web-tests.yml
vendored
@ -1,7 +1,11 @@
|
||||
name: Web Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- web/**
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -47,11 +51,6 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Check i18n types synchronization
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm run check:i18n-types
|
||||
|
||||
- name: Run tests
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@ -123,12 +123,10 @@ venv.bak/
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# type checking
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
pyrightconfig.json
|
||||
!api/pyrightconfig.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
@ -197,8 +195,8 @@ sdks/python-client/dify_client.egg-info
|
||||
.vscode/*
|
||||
!.vscode/launch.json.template
|
||||
!.vscode/README.md
|
||||
pyrightconfig.json
|
||||
api/.vscode
|
||||
web/.vscode
|
||||
# vscode Code History Extension
|
||||
.history
|
||||
|
||||
@ -216,13 +214,6 @@ mise.toml
|
||||
# Next.js build output
|
||||
.next/
|
||||
|
||||
# PWA generated files
|
||||
web/public/sw.js
|
||||
web/public/sw.js.map
|
||||
web/public/workbox-*.js
|
||||
web/public/workbox-*.js.map
|
||||
web/public/fallback-*.js
|
||||
|
||||
# AI Assistant
|
||||
.roo/
|
||||
api/.env.backup
|
||||
|
||||
34
.mcp.json
34
.mcp.json
@ -1,34 +0,0 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"context7": {
|
||||
"type": "http",
|
||||
"url": "https://mcp.context7.com/mcp"
|
||||
},
|
||||
"sequential-thinking": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
|
||||
"env": {}
|
||||
},
|
||||
"github": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {
|
||||
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
|
||||
}
|
||||
},
|
||||
"fetch": {
|
||||
"type": "stdio",
|
||||
"command": "uvx",
|
||||
"args": ["mcp-server-fetch"],
|
||||
"env": {}
|
||||
},
|
||||
"playwright": {
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "@playwright/mcp@latest"],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
|
||||
./dev/reformat # Run all formatters and linters
|
||||
uv run --project api ruff check --fix ./ # Fix linting issues
|
||||
uv run --project api ruff format ./ # Format code
|
||||
uv run --directory api basedpyright # Type checking
|
||||
uv run --project api mypy . # Type checking
|
||||
```
|
||||
|
||||
### Frontend (Web)
|
||||
@ -86,4 +86,3 @@ pnpm test # Run Jest tests
|
||||
## Project-Specific Conventions
|
||||
|
||||
- All async tasks use Celery with Redis as broker
|
||||
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.
|
||||
|
||||
60
Makefile
60
Makefile
@ -4,48 +4,6 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
|
||||
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
|
||||
VERSION=latest
|
||||
|
||||
# Backend Development Environment Setup
|
||||
.PHONY: dev-setup prepare-docker prepare-web prepare-api
|
||||
|
||||
# Default dev setup target
|
||||
dev-setup: prepare-docker prepare-web prepare-api
|
||||
@echo "✅ Backend development environment setup complete!"
|
||||
|
||||
# Step 1: Prepare Docker middleware
|
||||
prepare-docker:
|
||||
@echo "🐳 Setting up Docker middleware..."
|
||||
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
|
||||
@echo "✅ Docker middleware started"
|
||||
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cd web && pnpm build
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
prepare-api:
|
||||
@echo "🔧 Setting up API environment..."
|
||||
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
|
||||
@cd api && uv sync --dev
|
||||
@cd api && uv run flask db upgrade
|
||||
@echo "✅ API environment prepared (not started)"
|
||||
|
||||
# Clean dev environment
|
||||
dev-clean:
|
||||
@echo "⚠️ Stopping Docker containers..."
|
||||
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
|
||||
@echo "🗑️ Removing volumes..."
|
||||
@rm -rf docker/volumes/db
|
||||
@rm -rf docker/volumes/redis
|
||||
@rm -rf docker/volumes/plugin_daemon
|
||||
@rm -rf docker/volumes/weaviate
|
||||
@rm -rf api/storage
|
||||
@echo "✅ Cleanup complete"
|
||||
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
@ -81,21 +39,5 @@ build-push-web: build-web push-web
|
||||
build-push-all: build-all push-all
|
||||
@echo "All Docker images have been built and pushed."
|
||||
|
||||
# Help target
|
||||
help:
|
||||
@echo "Development Setup Targets:"
|
||||
@echo " make dev-setup - Run all setup steps for backend dev environment"
|
||||
@echo " make prepare-docker - Set up Docker middleware"
|
||||
@echo " make prepare-web - Set up web environment"
|
||||
@echo " make prepare-api - Set up API environment"
|
||||
@echo " make dev-clean - Stop Docker middleware containers"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
@echo " make build-web - Build web Docker image"
|
||||
@echo " make build-api - Build API Docker image"
|
||||
@echo " make build-all - Build all Docker images"
|
||||
@echo " make push-all - Push all Docker images"
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all
|
||||
|
||||
@ -180,7 +180,7 @@ docker compose up -d
|
||||
|
||||
## Contributing
|
||||
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。
|
||||
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
|
||||
|
||||
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。
|
||||
|
||||
@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline
|
||||
|
||||
## Contributing
|
||||
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
|
||||
|
||||
> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
|
||||
@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @
|
||||
|
||||
## Contribuir
|
||||
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md).
|
||||
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias.
|
||||
|
||||
> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart
|
||||
|
||||
## Contribuer
|
||||
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md).
|
||||
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences.
|
||||
|
||||
> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -169,7 +169,7 @@ docker compose up -d
|
||||
|
||||
## 貢献
|
||||
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。
|
||||
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
|
||||
同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。
|
||||
|
||||
> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。
|
||||
|
||||
@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
|
||||
|
||||
## 기여
|
||||
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요.
|
||||
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
|
||||
동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
|
||||
|
||||
> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.
|
||||
|
||||
@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by
|
||||
|
||||
## Contribuindo
|
||||
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md).
|
||||
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
|
||||
|
||||
> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
|
||||
|
||||
## Katkıda Bulunma
|
||||
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz.
|
||||
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
|
||||
Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün.
|
||||
|
||||
> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.
|
||||
|
||||
@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
|
||||
|
||||
## 貢獻
|
||||
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
|
||||
|
||||
> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
|
||||
|
||||
@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De
|
||||
|
||||
## Đóng góp
|
||||
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi.
|
||||
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
|
||||
Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị.
|
||||
|
||||
> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.
|
||||
|
||||
@ -75,7 +75,6 @@ DB_PASSWORD=difyai123456
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_DATABASE=dify
|
||||
SQLALCHEMY_POOL_PRE_PING=true
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
@ -461,16 +460,6 @@ WORKFLOW_CALL_MAX_DEPTH=5
|
||||
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
# Seconds of idle time before scaling down workers (default: 5.0)
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
# Workflow storage configuration
|
||||
# Options: rdbms, hybrid
|
||||
# rdbms: Use only the relational database (default)
|
||||
@ -575,7 +564,3 @@ QUEUE_MONITOR_THRESHOLD=200
|
||||
QUEUE_MONITOR_ALERT_EMAILS=
|
||||
# Monitor interval in minutes, default is 30 minutes
|
||||
QUEUE_MONITOR_INTERVAL=30
|
||||
|
||||
# Swagger UI configuration
|
||||
SWAGGER_UI_ENABLED=true
|
||||
SWAGGER_UI_PATH=/swagger-ui.html
|
||||
|
||||
@ -1,112 +0,0 @@
|
||||
[importlinter]
|
||||
root_packages =
|
||||
core
|
||||
configs
|
||||
controllers
|
||||
models
|
||||
tasks
|
||||
services
|
||||
|
||||
[importlinter:contract:workflow]
|
||||
name = Workflow
|
||||
type=layers
|
||||
layers =
|
||||
graph_engine
|
||||
graph_events
|
||||
graph
|
||||
nodes
|
||||
node_events
|
||||
entities
|
||||
containers =
|
||||
core.workflow
|
||||
ignore_imports =
|
||||
core.workflow.nodes.base.node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
|
||||
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
core.workflow.nodes.node_factory -> core.workflow.graph
|
||||
|
||||
[importlinter:contract:rsc]
|
||||
name = RSC
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
response_coordinator
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:worker]
|
||||
name = Worker
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
worker
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:graph-engine-architecture]
|
||||
name = Graph Engine Architecture
|
||||
type = layers
|
||||
layers =
|
||||
graph_engine
|
||||
orchestration
|
||||
command_processing
|
||||
event_management
|
||||
error_handling
|
||||
graph_traversal
|
||||
state_management
|
||||
worker_management
|
||||
domain
|
||||
containers =
|
||||
core.workflow.graph_engine
|
||||
|
||||
[importlinter:contract:domain-isolation]
|
||||
name = Domain Model Isolation
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.domain
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
core.workflow.graph_engine.command_channels
|
||||
core.workflow.graph_engine.layers
|
||||
core.workflow.graph_engine.protocols
|
||||
|
||||
[importlinter:contract:worker-management]
|
||||
name = Worker Management
|
||||
type = forbidden
|
||||
source_modules =
|
||||
core.workflow.graph_engine.worker_management
|
||||
forbidden_modules =
|
||||
core.workflow.graph_engine.orchestration
|
||||
core.workflow.graph_engine.command_processing
|
||||
core.workflow.graph_engine.event_management
|
||||
|
||||
[importlinter:contract:error-handling-strategies]
|
||||
name = Error Handling Strategies
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.error_handling.abort_strategy
|
||||
core.workflow.graph_engine.error_handling.retry_strategy
|
||||
core.workflow.graph_engine.error_handling.fail_branch_strategy
|
||||
core.workflow.graph_engine.error_handling.default_value_strategy
|
||||
|
||||
[importlinter:contract:graph-traversal-components]
|
||||
name = Graph Traversal Components
|
||||
type = layers
|
||||
layers =
|
||||
edge_processor
|
||||
skip_propagator
|
||||
containers =
|
||||
core.workflow.graph_engine.graph_traversal
|
||||
|
||||
[importlinter:contract:command-channels]
|
||||
name = Command Channels Independence
|
||||
type = independence
|
||||
modules =
|
||||
core.workflow.graph_engine.command_channels.in_memory_channel
|
||||
core.workflow.graph_engine.command_channels.redis_channel
|
||||
@ -43,7 +43,6 @@ select = [
|
||||
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
|
||||
"S311", # suspicious-non-cryptographic-random-usage
|
||||
"G001", # don't use str format to logging messages
|
||||
"G003", # don't use + in logging messages
|
||||
"G004", # don't use f-strings to format logging messages
|
||||
]
|
||||
|
||||
|
||||
@ -97,16 +97,8 @@ uv run celery -A app.celery beat
|
||||
uv sync --dev
|
||||
```
|
||||
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
|
||||
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
|
||||
|
||||
```bash
|
||||
uv run pytest # Run all tests
|
||||
uv run pytest tests/unit_tests/ # Unit tests only
|
||||
uv run pytest tests/integration_tests/ # Integration tests
|
||||
|
||||
# Code quality
|
||||
../dev/reformat # Run all formatters and linters
|
||||
uv run ruff check --fix ./ # Fix linting issues
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
uv run -P api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
27
api/app.py
27
api/app.py
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@ -16,20 +17,20 @@ else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
@ -5,8 +5,6 @@ from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@ -25,9 +23,6 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Capture the decorator's return value to avoid pyright reportUnusedFunction
|
||||
_ = before_request
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
@ -37,7 +32,7 @@ def create_app() -> DifyApp:
|
||||
initialize_extensions(app)
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
|
||||
return app
|
||||
|
||||
|
||||
@ -98,14 +93,14 @@ def initialize_extensions(app: DifyApp):
|
||||
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
|
||||
if not is_enabled:
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Skipped %s", short_name)
|
||||
logging.info("Skipped %s", short_name)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
ext.init_app(app)
|
||||
end_time = time.perf_counter()
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
|
||||
logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
|
||||
|
||||
|
||||
def create_migrations_app():
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
import logging
|
||||
|
||||
import psycogreen.gevent as pscycogreen_gevent # type: ignore
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(message: str):
|
||||
print(message, flush=True)
|
||||
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
_log("gRPC patched with gevent.")
|
||||
pscycogreen_gevent.patch_psycopg()
|
||||
_log("psycopg2 patched with gevent.")
|
||||
|
||||
|
||||
from app import app, celery
|
||||
|
||||
__all__ = ["app", "celery"]
|
||||
11
api/child_class.py
Normal file
11
api/child_class.py
Normal file
@ -0,0 +1,11 @@
|
||||
from tests.integration_tests.utils.parent_class import ParentClass
|
||||
|
||||
|
||||
class ChildClass(ParentClass):
|
||||
"""Test child class for module import helper tests"""
|
||||
|
||||
def __init__(self, name):
|
||||
super().__init__(name)
|
||||
|
||||
def get_name(self):
|
||||
return f"Child: {self.name}"
|
||||
275
api/commands.py
275
api/commands.py
@ -13,13 +13,11 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from core.helper import encrypter
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.models.document import Document
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from events.app_event import app_was_created
|
||||
from extensions.ext_database import db
|
||||
@ -32,20 +30,14 @@ from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from models.provider import Provider, ProviderModel
|
||||
from models.provider_ids import DatasourceProviderID, ToolProviderID
|
||||
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@ -577,7 +569,7 @@ def old_metadata_migration():
|
||||
for document in documents:
|
||||
if document.doc_metadata:
|
||||
doc_metadata = document.doc_metadata
|
||||
for key in doc_metadata:
|
||||
for key, value in doc_metadata.items():
|
||||
for field in BuiltInField:
|
||||
if field.value == key:
|
||||
break
|
||||
@ -693,7 +685,7 @@ def upgrade_db():
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to execute database migration")
|
||||
logging.exception("Failed to execute database migration")
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
@ -741,7 +733,7 @@ where sites.id is null limit 1000"""
|
||||
except Exception:
|
||||
failed_app_ids.append(app_id)
|
||||
click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
|
||||
logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
|
||||
logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
|
||||
continue
|
||||
|
||||
if not processed_count:
|
||||
@ -1239,17 +1231,15 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
Count orphaned draft variables by app.
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics about orphaned variables and files
|
||||
Dictionary with statistics about orphaned variables
|
||||
"""
|
||||
# Count orphaned variables by app
|
||||
variables_query = """
|
||||
query = """
|
||||
SELECT
|
||||
wdv.app_id,
|
||||
COUNT(*) as variable_count,
|
||||
COUNT(wdv.file_id) as file_count
|
||||
COUNT(*) as variable_count
|
||||
FROM workflow_draft_variables AS wdv
|
||||
WHERE NOT EXISTS(
|
||||
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
|
||||
@ -1259,21 +1249,14 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app = {}
|
||||
total_files = 0
|
||||
result = conn.execute(sa.text(query))
|
||||
orphaned_by_app = {row[0]: row[1] for row in result}
|
||||
|
||||
for row in result:
|
||||
app_id, variable_count, file_count = row
|
||||
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
|
||||
total_files += file_count
|
||||
|
||||
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
|
||||
total_orphaned = sum(orphaned_by_app.values())
|
||||
app_count = len(orphaned_by_app)
|
||||
|
||||
return {
|
||||
"total_orphaned_variables": total_orphaned,
|
||||
"total_orphaned_files": total_files,
|
||||
"orphaned_app_count": app_count,
|
||||
"orphaned_by_app": orphaned_by_app,
|
||||
}
|
||||
@ -1302,7 +1285,6 @@ def cleanup_orphaned_draft_variables(
|
||||
stats = _count_orphaned_draft_variables()
|
||||
|
||||
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
|
||||
logger.info("Found %s associated offload files", stats["total_orphaned_files"])
|
||||
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
|
||||
|
||||
if stats["total_orphaned_variables"] == 0:
|
||||
@ -1311,10 +1293,10 @@ def cleanup_orphaned_draft_variables(
|
||||
|
||||
if dry_run:
|
||||
logger.info("DRY RUN: Would delete the following:")
|
||||
for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
|
||||
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
|
||||
:10
|
||||
]: # Show top 10
|
||||
logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
|
||||
logger.info(" App %s: %s variables", app_id, count)
|
||||
if len(stats["orphaned_by_app"]) > 10:
|
||||
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
|
||||
return
|
||||
@ -1323,8 +1305,7 @@ def cleanup_orphaned_draft_variables(
|
||||
if not force:
|
||||
click.confirm(
|
||||
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
|
||||
f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
|
||||
f"from {stats['orphaned_app_count']} apps?",
|
||||
f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
|
||||
abort=True,
|
||||
)
|
||||
|
||||
@ -1357,231 +1338,3 @@ def cleanup_orphaned_draft_variables(
|
||||
continue
|
||||
|
||||
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
|
||||
|
||||
|
||||
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
|
||||
@click.option("--provider", prompt=True, help="Provider name")
|
||||
@click.option("--client-params", prompt=True, help="Client Params")
|
||||
def setup_datasource_oauth_client(provider, client_params):
|
||||
"""
|
||||
Setup datasource oauth client
|
||||
"""
|
||||
provider_id = DatasourceProviderID(provider)
|
||||
provider_name = provider_id.provider_name
|
||||
plugin_id = provider_id.plugin_id
|
||||
|
||||
try:
|
||||
# json validate
|
||||
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
|
||||
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
|
||||
click.echo(click.style("Client params validated successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
|
||||
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
|
||||
deleted_count = (
|
||||
db.session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
if deleted_count > 0:
|
||||
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
|
||||
|
||||
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
|
||||
oauth_client = DatasourceOauthParamConfig(
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
system_credentials=client_params_dict,
|
||||
)
|
||||
db.session.add(oauth_client)
|
||||
db.session.commit()
|
||||
click.echo(click.style(f"provider: {provider_name}", fg="green"))
|
||||
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
|
||||
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
|
||||
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
|
||||
|
||||
|
||||
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
|
||||
def transform_datasource_credentials():
|
||||
"""
|
||||
Transform datasource credentials
|
||||
"""
|
||||
try:
|
||||
installer_manager = PluginInstaller()
|
||||
plugin_migration = PluginMigration()
|
||||
|
||||
notion_plugin_id = "langgenius/notion_datasource"
|
||||
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
|
||||
jina_plugin_id = "langgenius/jina_datasource"
|
||||
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
|
||||
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
|
||||
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
|
||||
oauth_credential_type = CredentialType.OAUTH2
|
||||
api_key_credential_type = CredentialType.API_KEY
|
||||
|
||||
# deal notion credentials
|
||||
deal_notion_count = 0
|
||||
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
|
||||
if notion_credentials:
|
||||
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
|
||||
for credential in notion_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in notion_credentials_tenant_mapping:
|
||||
notion_credentials_tenant_mapping[tenant_id] = []
|
||||
notion_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in notion_credentials_tenant_mapping.items():
|
||||
# check notion plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if notion_plugin_id not in installed_plugins_ids:
|
||||
if notion_plugin_unique_identifier:
|
||||
# install notion plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential oauth params
|
||||
access_token = credential.access_token
|
||||
# notion info
|
||||
notion_info = credential.source_info
|
||||
workspace_id = notion_info.get("workspace_id")
|
||||
workspace_name = notion_info.get("workspace_name")
|
||||
workspace_icon = notion_info.get("workspace_icon")
|
||||
new_credentials = {
|
||||
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
|
||||
"workspace_id": workspace_id,
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="notion_datasource",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=notion_plugin_id,
|
||||
auth_type=oauth_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url=workspace_icon or "default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_notion_count += 1
|
||||
db.session.commit()
|
||||
# deal firecrawl credentials
|
||||
deal_firecrawl_count = 0
|
||||
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
|
||||
if firecrawl_credentials:
|
||||
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for credential in firecrawl_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in firecrawl_credentials_tenant_mapping:
|
||||
firecrawl_credentials_tenant_mapping[tenant_id] = []
|
||||
firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
|
||||
# check firecrawl plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if firecrawl_plugin_id not in installed_plugins_ids:
|
||||
if firecrawl_plugin_unique_identifier:
|
||||
# install firecrawl plugin
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
base_url = credentials_json.get("config", {}).get("base_url")
|
||||
new_credentials = {
|
||||
"firecrawl_api_key": api_key,
|
||||
"base_url": base_url,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="firecrawl",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=firecrawl_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_firecrawl_count += 1
|
||||
db.session.commit()
|
||||
# deal jina credentials
|
||||
deal_jina_count = 0
|
||||
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
|
||||
if jina_credentials:
|
||||
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
|
||||
for credential in jina_credentials:
|
||||
tenant_id = credential.tenant_id
|
||||
if tenant_id not in jina_credentials_tenant_mapping:
|
||||
jina_credentials_tenant_mapping[tenant_id] = []
|
||||
jina_credentials_tenant_mapping[tenant_id].append(credential)
|
||||
for tenant_id, credentials in jina_credentials_tenant_mapping.items():
|
||||
# check jina plugin is installed
|
||||
installed_plugins = installer_manager.list_plugins(tenant_id)
|
||||
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
|
||||
if jina_plugin_id not in installed_plugins_ids:
|
||||
if jina_plugin_unique_identifier:
|
||||
# install jina plugin
|
||||
print(jina_plugin_unique_identifier)
|
||||
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
|
||||
|
||||
auth_count = 0
|
||||
for credential in credentials:
|
||||
auth_count += 1
|
||||
# get credential api key
|
||||
credentials_json = json.loads(credential.credentials)
|
||||
api_key = credentials_json.get("config", {}).get("api_key")
|
||||
new_credentials = {
|
||||
"integration_secret": api_key,
|
||||
}
|
||||
datasource_provider = DatasourceProvider(
|
||||
provider="jina",
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=jina_plugin_id,
|
||||
auth_type=api_key_credential_type.value,
|
||||
encrypted_credentials=new_credentials,
|
||||
name=f"Auth {auth_count}",
|
||||
avatar_url="default",
|
||||
is_default=False,
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
deal_jina_count += 1
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
return
|
||||
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
|
||||
click.echo(
|
||||
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
|
||||
)
|
||||
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
|
||||
|
||||
|
||||
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
|
||||
@click.option(
|
||||
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
|
||||
)
|
||||
@click.option(
|
||||
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
|
||||
)
|
||||
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
|
||||
def install_rag_pipeline_plugins(input_file, output_file, workers):
|
||||
"""
|
||||
Install rag pipeline plugins
|
||||
"""
|
||||
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
|
||||
plugin_migration = PluginMigration()
|
||||
plugin_migration.install_rag_pipeline_plugins(
|
||||
input_file,
|
||||
output_file,
|
||||
workers,
|
||||
)
|
||||
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
@ -499,22 +499,6 @@ class UpdateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class WorkflowVariableTruncationConfig(BaseSettings):
|
||||
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
|
||||
# 100KB
|
||||
1024_000,
|
||||
description="Maximum size for variable to trigger final truncation.",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
|
||||
50000,
|
||||
description="maximum length for string to trigger tuncation, measure in number of characters",
|
||||
)
|
||||
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
|
||||
100,
|
||||
description="maximum length for array to trigger truncation.",
|
||||
)
|
||||
|
||||
|
||||
class WorkflowConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for workflow execution
|
||||
@ -545,28 +529,6 @@ class WorkflowConfig(BaseSettings):
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
description="Maximum number of workers per GraphEngine instance",
|
||||
default=10,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
|
||||
description="Queue depth threshold that triggers worker scale up",
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
|
||||
description="Seconds of idle time before scaling down workers",
|
||||
default=5.0,
|
||||
ge=0.1,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowNodeExecutionConfig(BaseSettings):
|
||||
"""
|
||||
@ -1014,18 +976,6 @@ class WorkflowLogConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SwaggerUIConfig(BaseSettings):
|
||||
SWAGGER_UI_ENABLED: bool = Field(
|
||||
description="Whether to enable Swagger UI in api module",
|
||||
default=True,
|
||||
)
|
||||
|
||||
SWAGGER_UI_PATH: str = Field(
|
||||
description="Swagger UI page path in api module",
|
||||
default="/swagger-ui.html",
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@ -1057,12 +1007,10 @@ class FeatureConfig(
|
||||
WorkspaceConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
CeleryScheduleTasksConfig,
|
||||
WorkflowLogConfig,
|
||||
WorkflowVariableTruncationConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
@ -222,28 +222,11 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="remote",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
||||
@ -215,7 +215,6 @@ class DatabaseConfig(BaseSettings):
|
||||
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
|
||||
"connect_args": connect_args,
|
||||
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
|
||||
"pool_reset_on_return": None,
|
||||
}
|
||||
|
||||
|
||||
@ -300,7 +299,8 @@ class DatasetQueueMonitorConfig(BaseSettings):
|
||||
|
||||
class MiddlewareConfig(
|
||||
# place the configs in alphabet order
|
||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||
CeleryConfig,
|
||||
DatabaseConfig,
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
# configs of storage and storage providers
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ClickzettaConfig(BaseSettings):
|
||||
class ClickzettaConfig(BaseModel):
|
||||
"""
|
||||
Clickzetta Lakehouse vector database configuration
|
||||
"""
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MatrixoneConfig(BaseSettings):
|
||||
class MatrixoneConfig(BaseModel):
|
||||
"""Matrixone vector database configuration."""
|
||||
|
||||
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from configs.packaging.pyproject import PyProjectTomlConfig
|
||||
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
|
||||
|
||||
|
||||
class PackagingInfo(PyProjectTomlConfig):
|
||||
|
||||
@ -4,9 +4,8 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable, Mapping
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .python_3x import http_request, makedirs_wrapper
|
||||
from .utils import (
|
||||
@ -26,13 +25,13 @@ logger = logging.getLogger(__name__)
|
||||
class ApolloClient:
|
||||
def __init__(
|
||||
self,
|
||||
config_url: str,
|
||||
app_id: str,
|
||||
cluster: str = "default",
|
||||
secret: str = "",
|
||||
start_hot_update: bool = True,
|
||||
change_listener: Callable[[str, str, str, Any], None] | None = None,
|
||||
_notification_map: dict[str, int] | None = None,
|
||||
config_url,
|
||||
app_id,
|
||||
cluster="default",
|
||||
secret="",
|
||||
start_hot_update=True,
|
||||
change_listener=None,
|
||||
_notification_map=None,
|
||||
):
|
||||
# Core routing parameters
|
||||
self.config_url = config_url
|
||||
@ -48,17 +47,17 @@ class ApolloClient:
|
||||
# Private control variables
|
||||
self._cycle_time = 5
|
||||
self._stopping = False
|
||||
self._cache: dict[str, dict[str, Any]] = {}
|
||||
self._no_key: dict[str, str] = {}
|
||||
self._hash: dict[str, str] = {}
|
||||
self._cache = {}
|
||||
self._no_key = {}
|
||||
self._hash = {}
|
||||
self._pull_timeout = 75
|
||||
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||
self._long_poll_thread: threading.Thread | None = None
|
||||
self._long_poll_thread = None
|
||||
self._change_listener = change_listener # "add" "delete" "update"
|
||||
if _notification_map is None:
|
||||
_notification_map = {"application": -1}
|
||||
self._notification_map = _notification_map
|
||||
self.last_release_key: str | None = None
|
||||
self.last_release_key = None
|
||||
# Private startup method
|
||||
self._path_checker()
|
||||
if start_hot_update:
|
||||
@ -69,7 +68,7 @@ class ApolloClient:
|
||||
heartbeat.daemon = True
|
||||
heartbeat.start()
|
||||
|
||||
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
|
||||
def get_json_from_net(self, namespace="application"):
|
||||
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||
)
|
||||
@ -89,7 +88,7 @@ class ApolloClient:
|
||||
logger.exception("an error occurred in get_json_from_net")
|
||||
return None
|
||||
|
||||
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
|
||||
def get_value(self, key, default_val=None, namespace="application"):
|
||||
try:
|
||||
# read memory configuration
|
||||
namespace_cache = self._cache.get(namespace)
|
||||
@ -105,8 +104,7 @@ class ApolloClient:
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
val = get_value_from_dict(namespace_data, key)
|
||||
if val is not None:
|
||||
if namespace_data is not None:
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
self._update_cache_and_file(namespace_data, namespace)
|
||||
return val
|
||||
|
||||
# read the file configuration
|
||||
@ -128,23 +126,23 @@ class ApolloClient:
|
||||
# to ensure the real-time correctness of the function call.
|
||||
# If the user does not have the same default val twice
|
||||
# and the default val is used here, there may be a problem.
|
||||
def _set_local_cache_none(self, namespace: str, key: str) -> None:
|
||||
def _set_local_cache_none(self, namespace, key):
|
||||
no_key = no_key_cache_key(namespace, key)
|
||||
self._no_key[no_key] = key
|
||||
|
||||
def _start_hot_update(self) -> None:
|
||||
def _start_hot_update(self):
|
||||
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||
# when the main thread is launched.
|
||||
self._long_poll_thread.daemon = True
|
||||
self._long_poll_thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
def stop(self):
|
||||
self._stopping = True
|
||||
logger.info("Stopping listener...")
|
||||
|
||||
# Call the set callback function, and if it is abnormal, try it out
|
||||
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
|
||||
def _call_listener(self, namespace, old_kv, new_kv):
|
||||
if self._change_listener is None:
|
||||
return
|
||||
if old_kv is None:
|
||||
@ -170,12 +168,12 @@ class ApolloClient:
|
||||
except BaseException as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _path_checker(self) -> None:
|
||||
def _path_checker(self):
|
||||
if not os.path.isdir(self._cache_file_path):
|
||||
makedirs_wrapper(self._cache_file_path)
|
||||
|
||||
# update the local cache and file cache
|
||||
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
|
||||
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
||||
# update the local cache
|
||||
self._cache[namespace] = namespace_data
|
||||
# update the file cache
|
||||
@ -189,7 +187,7 @@ class ApolloClient:
|
||||
self._hash[namespace] = new_hash
|
||||
|
||||
# get the configuration from the local file
|
||||
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
|
||||
def _get_local_cache(self, namespace="application"):
|
||||
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||
if os.path.isfile(cache_file_path):
|
||||
with open(cache_file_path) as f:
|
||||
@ -197,8 +195,8 @@ class ApolloClient:
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _long_poll(self) -> None:
|
||||
notifications: list[dict[str, Any]] = []
|
||||
def _long_poll(self):
|
||||
notifications = []
|
||||
for key in self._cache:
|
||||
namespace_data = self._cache[key]
|
||||
notification_id = -1
|
||||
@ -238,7 +236,7 @@ class ApolloClient:
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
|
||||
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
||||
namespace_data = self.get_json_from_net(namespace)
|
||||
if not namespace_data:
|
||||
return
|
||||
@ -250,7 +248,7 @@ class ApolloClient:
|
||||
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||
self._call_listener(namespace, old_kv, new_kv)
|
||||
|
||||
def _listener(self) -> None:
|
||||
def _listener(self):
|
||||
logger.info("start long_poll")
|
||||
while not self._stopping:
|
||||
self._long_poll()
|
||||
@ -268,13 +266,13 @@ class ApolloClient:
|
||||
headers["Timestamp"] = time_unix_now
|
||||
return headers
|
||||
|
||||
def _heart_beat(self) -> None:
|
||||
def _heart_beat(self):
|
||||
while not self._stopping:
|
||||
for namespace in self._notification_map:
|
||||
self._do_heart_beat(namespace)
|
||||
time.sleep(60 * 10) # 10 minutes
|
||||
|
||||
def _do_heart_beat(self, namespace: str) -> None:
|
||||
def _do_heart_beat(self, namespace):
|
||||
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
|
||||
try:
|
||||
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||
@ -294,7 +292,7 @@ class ApolloClient:
|
||||
logger.exception("an error occurred in _do_heart_beat")
|
||||
return None
|
||||
|
||||
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
|
||||
def get_all_dicts(self, namespace):
|
||||
namespace_data = self._cache.get(namespace)
|
||||
if namespace_data is None:
|
||||
net_namespace_data = self.get_json_from_net(namespace)
|
||||
|
||||
@ -2,8 +2,6 @@ import logging
|
||||
import os
|
||||
import ssl
|
||||
import urllib.request
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from urllib import parse
|
||||
from urllib.error import HTTPError
|
||||
|
||||
@ -21,9 +19,9 @@ urllib.request.install_opener(opener)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
|
||||
def http_request(url, timeout, headers={}):
|
||||
try:
|
||||
request = urllib.request.Request(url, headers=dict(headers))
|
||||
request = urllib.request.Request(url, headers=headers)
|
||||
res = urllib.request.urlopen(request, timeout=timeout)
|
||||
body = res.read().decode("utf-8")
|
||||
return res.code, body
|
||||
@ -35,9 +33,9 @@ def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}
|
||||
raise e
|
||||
|
||||
|
||||
def url_encode(params: dict[str, Any]) -> str:
|
||||
def url_encode(params):
|
||||
return parse.urlencode(params)
|
||||
|
||||
|
||||
def makedirs_wrapper(path: str) -> None:
|
||||
def makedirs_wrapper(path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import hashlib
|
||||
import socket
|
||||
from typing import Any
|
||||
|
||||
from .python_3x import url_encode
|
||||
|
||||
@ -11,7 +10,7 @@ NAMESPACE_NAME = "namespaceName"
|
||||
|
||||
|
||||
# add timestamps uris and keys
|
||||
def signature(timestamp: str, uri: str, secret: str) -> str:
|
||||
def signature(timestamp, uri, secret):
|
||||
import base64
|
||||
import hmac
|
||||
|
||||
@ -20,16 +19,16 @@ def signature(timestamp: str, uri: str, secret: str) -> str:
|
||||
return base64.b64encode(hmac_code).decode()
|
||||
|
||||
|
||||
def url_encode_wrapper(params: dict[str, Any]) -> str:
|
||||
def url_encode_wrapper(params):
|
||||
return url_encode(params)
|
||||
|
||||
|
||||
def no_key_cache_key(namespace: str, key: str) -> str:
|
||||
def no_key_cache_key(namespace, key):
|
||||
return f"{namespace}{len(namespace)}{key}"
|
||||
|
||||
|
||||
# Returns whether the obtained value is obtained, and None if it does not
|
||||
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
|
||||
def get_value_from_dict(namespace_cache, key):
|
||||
if namespace_cache:
|
||||
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||
if kv_data is None:
|
||||
@ -39,7 +38,7 @@ def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any
|
||||
return None
|
||||
|
||||
|
||||
def init_ip() -> str:
|
||||
def init_ip():
|
||||
ip = ""
|
||||
s = None
|
||||
try:
|
||||
|
||||
@ -11,5 +11,5 @@ class RemoteSettingsSource:
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
return value
|
||||
|
||||
@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||
|
||||
from .utils import parse_config
|
||||
from .utils import _parse_config
|
||||
|
||||
|
||||
class NacosSettingsSource(RemoteSettingsSource):
|
||||
def __init__(self, configs: Mapping[str, Any]):
|
||||
self.configs = configs
|
||||
self.remote_configs: dict[str, str] = {}
|
||||
self.remote_configs: dict[str, Any] = {}
|
||||
self.async_init()
|
||||
|
||||
def async_init(self) -> None:
|
||||
def async_init(self):
|
||||
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
|
||||
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
|
||||
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
|
||||
@ -29,19 +29,22 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
try:
|
||||
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
|
||||
self.remote_configs = self._parse_config(content)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("[get-access-token] exception occurred")
|
||||
raise
|
||||
|
||||
def _parse_config(self, content: str) -> dict[str, str]:
|
||||
def _parse_config(self, content: str) -> dict:
|
||||
if not content:
|
||||
return {}
|
||||
try:
|
||||
return parse_config(content)
|
||||
return _parse_config(self, content)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
return None, field_name, False
|
||||
|
||||
@ -17,26 +17,20 @@ class NacosHttpClient:
|
||||
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
|
||||
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
|
||||
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
|
||||
self.token: str | None = None
|
||||
self.token = None
|
||||
self.token_ttl = 18000
|
||||
self.token_expire_time: float = 0
|
||||
|
||||
def http_request(
|
||||
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
|
||||
) -> str:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
if params is None:
|
||||
params = {}
|
||||
def http_request(self, url, method="GET", headers=None, params=None):
|
||||
try:
|
||||
self._inject_auth_info(headers, params)
|
||||
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except requests.RequestException as e:
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"Request to Nacos failed: {e}"
|
||||
|
||||
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
|
||||
def _inject_auth_info(self, headers, params, module="config"):
|
||||
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
|
||||
|
||||
if module == "login":
|
||||
@ -51,17 +45,16 @@ class NacosHttpClient:
|
||||
headers["timeStamp"] = ts
|
||||
if self.username and self.password:
|
||||
self.get_access_token(force_refresh=False)
|
||||
if self.token is not None:
|
||||
params["accessToken"] = self.token
|
||||
params["accessToken"] = self.token
|
||||
|
||||
def __do_sign(self, sign_str: str, sk: str) -> str:
|
||||
def __do_sign(self, sign_str, sk):
|
||||
return (
|
||||
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
|
||||
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
|
||||
def get_sign_str(self, group, tenant, ts):
|
||||
sign_str = ""
|
||||
if tenant:
|
||||
sign_str = tenant + "+"
|
||||
@ -70,7 +63,7 @@ class NacosHttpClient:
|
||||
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
|
||||
return sign_str
|
||||
|
||||
def get_access_token(self, force_refresh: bool = False) -> str | None:
|
||||
def get_access_token(self, force_refresh=False):
|
||||
current_time = time.time()
|
||||
if self.token and not force_refresh and self.token_expire_time > current_time:
|
||||
return self.token
|
||||
@ -84,7 +77,6 @@ class NacosHttpClient:
|
||||
self.token = response_data.get("accessToken")
|
||||
self.token_ttl = response_data.get("tokenTtl", 18000)
|
||||
self.token_expire_time = current_time + self.token_ttl - 10
|
||||
return self.token
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("[get-access-token] exception occur")
|
||||
raise
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
def parse_config(content: str) -> dict[str, str]:
|
||||
def _parse_config(self, content: str) -> dict[str, str]:
|
||||
config: dict[str, str] = {}
|
||||
if not content:
|
||||
return config
|
||||
|
||||
@ -19,7 +19,6 @@ language_timezone_mapping = {
|
||||
"fa-IR": "Asia/Tehran",
|
||||
"sl-SI": "Europe/Ljubljana",
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
||||
@ -3,7 +3,6 @@ from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
@ -43,7 +43,7 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
|
||||
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
|
||||
|
||||
# Import other controllers
|
||||
from . import admin, apikey, extension, feature, ping, setup, spec, version
|
||||
from . import admin, apikey, extension, feature, ping, setup, version
|
||||
|
||||
# Import app controllers
|
||||
from .app import (
|
||||
@ -70,7 +70,7 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing, compliance
|
||||
@ -84,17 +84,9 @@ from .datasets import (
|
||||
external,
|
||||
hit_testing,
|
||||
metadata,
|
||||
upload_file,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
datasource_auth,
|
||||
datasource_content_preview,
|
||||
rag_pipeline,
|
||||
rag_pipeline_datasets,
|
||||
rag_pipeline_draft_variable,
|
||||
rag_pipeline_import,
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
@ -8,8 +6,6 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
@ -18,9 +14,9 @@ from extensions.ext_database import db
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
def admin_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -134,19 +130,15 @@ class InsertExploreAppApi(Resource):
|
||||
app.is_public = False
|
||||
|
||||
with Session(db.engine) as session:
|
||||
installed_apps = (
|
||||
session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
installed_apps = session.execute(
|
||||
select(InstalledApp).where(
|
||||
InstalledApp.app_id == recommended_app.app_id,
|
||||
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
|
||||
for installed_app in installed_apps:
|
||||
session.delete(installed_app)
|
||||
for installed_app in installed_apps:
|
||||
db.session.delete(installed_app)
|
||||
|
||||
db.session.delete(recommended_app)
|
||||
db.session.commit()
|
||||
|
||||
@ -84,10 +84,10 @@ class BaseApiKeyListResource(Resource):
|
||||
flask_restx.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
custom="max_keys_exceeded",
|
||||
code="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
api_token = ApiToken()
|
||||
setattr(api_token, self.resource_id_field, resource_id)
|
||||
api_token.tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -237,14 +237,9 @@ class AppExportApi(Resource):
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
|
||||
parser.add_argument("workflow_id", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {
|
||||
"data": AppDslService.export_dsl(
|
||||
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
|
||||
)
|
||||
}
|
||||
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
|
||||
|
||||
|
||||
class AppNameApi(Resource):
|
||||
|
||||
@ -31,8 +31,6 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageAudioApi(Resource):
|
||||
@setup_required
|
||||
@ -51,7 +49,7 @@ class ChatMessageAudioApi(Resource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -72,7 +70,7 @@ class ChatMessageAudioApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to handle post request to ChatMessageAudioApi")
|
||||
logging.exception("Failed to handle post request to ChatMessageAudioApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -99,7 +97,7 @@ class ChatMessageTextApi(Resource):
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -120,7 +118,7 @@ class ChatMessageTextApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to handle post request to ChatMessageTextApi")
|
||||
logging.exception("Failed to handle post request to ChatMessageTextApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -162,7 +160,7 @@ class TextModesApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("Failed to handle get request to TextModesApi")
|
||||
logging.exception("Failed to handle get request to TextModesApi")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -34,8 +34,6 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion message api for user
|
||||
class CompletionMessageApi(Resource):
|
||||
@ -69,7 +67,7 @@ class CompletionMessageApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -82,7 +80,7 @@ class CompletionMessageApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -136,7 +134,7 @@ class ChatMessageApi(Resource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -151,7 +149,7 @@ class ChatMessageApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def delete(self, app_model, conversation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -16,10 +16,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class RuleGenerateApi(Resource):
|
||||
@ -138,6 +135,9 @@ class InstructionGenerateApi(Resource):
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
|
||||
from models import App, db
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
app = db.session.query(App).where(App.id == args["flow_id"]).first()
|
||||
if not app:
|
||||
return {"error": f"app {args['flow_id']} not found"}, 400
|
||||
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
def post(self) -> dict:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, default=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -3,7 +3,6 @@ import logging
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields, marshal_with, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from sqlalchemy import exists, select
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@ -34,8 +33,6 @@ from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMessageListApi(Resource):
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
@ -95,22 +92,21 @@ class ChatMessageListApi(Resource):
|
||||
.all()
|
||||
)
|
||||
|
||||
# Initialize has_more based on whether we have a full page
|
||||
has_more = False
|
||||
if len(history_messages) == args["limit"]:
|
||||
current_page_first_message = history_messages[-1]
|
||||
# Check if there are more messages before the current page
|
||||
has_more = db.session.scalar(
|
||||
select(
|
||||
exists().where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.where(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
else:
|
||||
# If we don't have a full page, there are no more messages
|
||||
has_more = False
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
@ -130,7 +126,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
message_id = str(args["message_id"])
|
||||
|
||||
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -219,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
except SuggestedQuestionsAfterAnswerDisabledError:
|
||||
raise AppSuggestedQuestionsAfterAnswerDisabledError()
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
@ -73,7 +72,6 @@ class DraftWorkflowApi(Resource):
|
||||
Get draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -96,7 +94,6 @@ class DraftWorkflowApi(Resource):
|
||||
Sync draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -174,7 +171,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
Run draft workflow
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -209,7 +205,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -222,12 +218,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -245,7 +242,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -259,11 +256,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -281,7 +279,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -294,13 +292,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -318,7 +316,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -331,13 +329,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
@ -355,7 +353,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -368,13 +366,13 @@ class DraftWorkflowRunApi(Resource):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
@ -407,19 +405,11 @@ class WorkflowTaskStopApi(Resource):
|
||||
"""
|
||||
Stop workflow task
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -434,13 +424,13 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("query", type=str, required=False, location="json", default="")
|
||||
@ -482,9 +472,6 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -504,12 +491,13 @@ class PublishedWorkflowApi(Resource):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
@ -532,7 +520,7 @@ class PublishedWorkflowApi(Resource):
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
|
||||
db.session.commit()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
@ -553,9 +541,6 @@ class DefaultBlockConfigsApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -574,12 +559,13 @@ class DefaultBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("q", type=str, location="args")
|
||||
args = parser.parse_args()
|
||||
@ -609,12 +595,13 @@ class ConvertToWorkflowApi(Resource):
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
if request.data:
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
|
||||
@ -658,9 +645,6 @@ class PublishedAllWorkflowApi(Resource):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -709,12 +693,13 @@ class WorkflowByIdApi(Resource):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
@ -765,12 +750,13 @@ class WorkflowByIdApi(Resource):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
|
||||
)
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import NoReturn
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
@ -13,17 +13,14 @@ from controllers.console.app.error import (
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.file import helpers as file_helpers
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
@ -31,7 +28,7 @@ from services.workflow_service import WorkflowService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
@ -42,7 +39,7 @@ def _convert_values_to_json_serializable_object(value: Segment):
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
@ -76,22 +73,6 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
"""Serialize full_content information for large variables."""
|
||||
if not variable.is_truncated():
|
||||
return None
|
||||
|
||||
variable_file = variable.variable_file
|
||||
assert variable_file is not None
|
||||
|
||||
return {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
@ -101,13 +82,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"value_type": fields.String(attribute=_serialize_variable_type),
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
full_content=fields.Raw(attribute=_serialize_full_content),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
@ -156,7 +135,6 @@ def _api_prerequisite(f):
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
assert isinstance(current_user, Account)
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -6,11 +6,9 @@ from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
assert isinstance(current_user, Account)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
|
||||
@ -13,8 +13,6 @@ from libs.oauth_data_source import NotionOAuth
|
||||
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
@ -81,8 +79,8 @@ class OAuthDataSourceBinding(Resource):
|
||||
return {"error": "Invalid code"}, 400
|
||||
try:
|
||||
oauth_provider.get_access_token(code)
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
@ -104,8 +102,8 @@ class OAuthDataSourceSync(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
try:
|
||||
oauth_provider.sync_data_source(binding_id)
|
||||
except requests.HTTPError as e:
|
||||
logger.exception(
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logging.exception(
|
||||
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
|
||||
)
|
||||
return {"error": "OAuth data source process failed"}, 400
|
||||
|
||||
@ -55,12 +55,6 @@ class EmailOrPasswordMismatchError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class AuthenticationFailedError(BaseHTTPException):
|
||||
error_code = "authentication_failed"
|
||||
description = "Invalid email or password."
|
||||
code = 401
|
||||
|
||||
|
||||
class EmailPasswordLoginLimitError(BaseHTTPException):
|
||||
error_code = "email_code_login_limit"
|
||||
description = "Too many incorrect password attempts. Please try again later."
|
||||
|
||||
@ -9,8 +9,8 @@ from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
EmailOrPasswordMismatchError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
@ -79,7 +79,7 @@ class LoginApi(Resource):
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise AuthenticationFailedError()
|
||||
raise EmailOrPasswordMismatchError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
@ -130,9 +130,8 @@ class ResetPasswordSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
@ -162,7 +161,7 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
language = "en-US"
|
||||
try:
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
if account is None:
|
||||
@ -200,7 +199,7 @@ class EmailCodeLoginApi(Resource):
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
try:
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
except AccountRegisterError:
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@ -222,8 +221,8 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError:
|
||||
return NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError as are:
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
|
||||
@ -24,8 +24,6 @@ from services.feature_service import FeatureService
|
||||
|
||||
from .. import api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_oauth_providers():
|
||||
with current_app.app_context():
|
||||
@ -80,9 +78,9 @@ class OAuthCallback(Resource):
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
except requests.RequestException as e:
|
||||
except requests.exceptions.RequestException as e:
|
||||
error_text = e.response.text if e.response else str(e)
|
||||
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
|
||||
@ -1,202 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
|
||||
import flask_login
|
||||
from flask import jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models.model import OAuthProviderApp
|
||||
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
|
||||
|
||||
from .. import api
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_id", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
client_id = parsed_args.get("client_id")
|
||||
if not client_id:
|
||||
raise BadRequest("client_id is required")
|
||||
|
||||
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
|
||||
if not oauth_provider_app:
|
||||
raise NotFound("client_id is invalid")
|
||||
|
||||
return view(self, oauth_provider_app, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
authorization_header = request.headers.get("Authorization")
|
||||
if not authorization_header:
|
||||
response = jsonify({"error": "Authorization header is required"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
parts = authorization_header.strip().split(None, 1)
|
||||
if len(parts) != 2:
|
||||
response = jsonify({"error": "Invalid Authorization header format"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
token_type = parts[0].strip()
|
||||
if token_type.lower() != "bearer":
|
||||
response = jsonify({"error": "token_type is invalid"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
access_token = parts[1].strip()
|
||||
if not access_token:
|
||||
response = jsonify({"error": "access_token is required"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
|
||||
if not account:
|
||||
response = jsonify({"error": "access_token or client_id is invalid"})
|
||||
response.status_code = 401
|
||||
response.headers["WWW-Authenticate"] = "Bearer"
|
||||
return response
|
||||
|
||||
return view(self, oauth_provider_app, account, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
class OAuthServerAppApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("redirect_uri", type=str, required=True, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
redirect_uri = parsed_args.get("redirect_uri")
|
||||
|
||||
# check if redirect_uri is valid
|
||||
if redirect_uri not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"app_icon": oauth_provider_app.app_icon,
|
||||
"app_label": oauth_provider_app.app_label,
|
||||
"scope": oauth_provider_app.scope,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAuthorizeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
user_account_id = account.id
|
||||
|
||||
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"code": code,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserTokenApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("grant_type", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=False, location="json")
|
||||
parser.add_argument("client_secret", type=str, required=False, location="json")
|
||||
parser.add_argument("redirect_uri", type=str, required=False, location="json")
|
||||
parser.add_argument("refresh_token", type=str, required=False, location="json")
|
||||
parsed_args = parser.parse_args()
|
||||
|
||||
try:
|
||||
grant_type = OAuthGrantType(parsed_args["grant_type"])
|
||||
except ValueError:
|
||||
raise BadRequest("invalid grant_type")
|
||||
|
||||
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
|
||||
if not parsed_args["code"]:
|
||||
raise BadRequest("code is required")
|
||||
|
||||
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
|
||||
raise BadRequest("client_secret is invalid")
|
||||
|
||||
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
|
||||
raise BadRequest("redirect_uri is invalid")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
|
||||
if not parsed_args["refresh_token"]:
|
||||
raise BadRequest("refresh_token is required")
|
||||
|
||||
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
|
||||
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
|
||||
)
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class OAuthServerUserAccountApi(Resource):
|
||||
@setup_required
|
||||
@oauth_server_client_id_required
|
||||
@oauth_server_access_token_required
|
||||
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
"avatar": account.avatar,
|
||||
"interface_language": account.interface_language,
|
||||
"timezone": account.timezone,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(OAuthServerAppApi, "/oauth/provider")
|
||||
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
|
||||
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
|
||||
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
|
||||
@ -1,9 +1,9 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import Account
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
|
||||
@ -17,10 +17,9 @@ class Subscription(Resource):
|
||||
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
|
||||
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
|
||||
args = parser.parse_args()
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
return BillingService.get_subscription(
|
||||
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
|
||||
)
|
||||
@ -32,9 +31,7 @@ class Invoices(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
assert isinstance(current_user, Account)
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
assert current_user.current_tenant_id is not None
|
||||
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@ -11,10 +9,7 @@ from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
@ -23,7 +18,6 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@ -118,18 +112,6 @@ class DataSourceNotionListApi(Resource):
|
||||
@marshal_with(integrate_notion_info_list_fields)
|
||||
def get(self):
|
||||
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
# import notion in the exist dataset
|
||||
@ -153,49 +135,31 @@ class DataSourceNotionListApi(Resource):
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info["notion_page_id"])
|
||||
# get all authorized pages
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="langgenius/notion_datasource/notion_datasource",
|
||||
datasource_name="notion_datasource",
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
if credential:
|
||||
datasource_runtime.runtime.credentials = credential
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_pages(
|
||||
user_id=current_user.id,
|
||||
datasource_parameters={},
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
data_source_bindings = session.scalars(
|
||||
select(DataSourceOauthBinding).filter_by(
|
||||
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
||||
)
|
||||
)
|
||||
try:
|
||||
pages = []
|
||||
workspace_info = {}
|
||||
for message in online_document_result:
|
||||
result = message.result
|
||||
for info in result:
|
||||
workspace_info = {
|
||||
"workspace_id": info.workspace_id,
|
||||
"workspace_name": info.workspace_name,
|
||||
"workspace_icon": info.workspace_icon,
|
||||
}
|
||||
for page in info.pages:
|
||||
page_info = {
|
||||
"page_id": page.page_id,
|
||||
"page_name": page.page_name,
|
||||
"type": page.type,
|
||||
"parent_id": page.parent_id,
|
||||
"is_bound": page.page_id in exist_page_ids,
|
||||
"page_icon": page.page_icon,
|
||||
}
|
||||
pages.append(page_info)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return {"notion_info": {**workspace_info, "pages": pages}}, 200
|
||||
).all()
|
||||
if not data_source_bindings:
|
||||
return {"notion_info": []}, 200
|
||||
pre_import_info_list = []
|
||||
for data_source_binding in data_source_bindings:
|
||||
source_info = data_source_binding.source_info
|
||||
pages = source_info["pages"]
|
||||
# Filter out already bound pages
|
||||
for page in pages:
|
||||
if page["page_id"] in exist_page_ids:
|
||||
page["is_bound"] = True
|
||||
else:
|
||||
page["is_bound"] = False
|
||||
pre_import_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
}
|
||||
pre_import_info_list.append(pre_import_info)
|
||||
return {"notion_info": pre_import_info_list}, 200
|
||||
|
||||
|
||||
class DataSourceNotionApi(Resource):
|
||||
@ -203,25 +167,27 @@ class DataSourceNotionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
credential_id = request.args.get("credential_id", default=None, type=str)
|
||||
if not credential_id:
|
||||
raise ValueError("Credential id is required.")
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
with Session(db.engine) as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).where(
|
||||
db.and_(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == "notion",
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not data_source_binding:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
extractor = NotionExtractor(
|
||||
notion_workspace_id=workspace_id,
|
||||
notion_obj_id=page_id,
|
||||
notion_page_type=page_type,
|
||||
notion_access_token=credential.get("integration_secret"),
|
||||
notion_access_token=data_source_binding.access_token,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
|
||||
@ -246,12 +212,10 @@ class DataSourceNotionApi(Resource):
|
||||
extract_settings = []
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
|
||||
@ -19,9 +19,9 @@ from controllers.console.wraps import (
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
@ -31,7 +31,6 @@ from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
@ -280,15 +279,6 @@ class DatasetApi(Resource):
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
@ -432,21 +422,17 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
if file_details:
|
||||
for file_detail in file_details:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value,
|
||||
upload_file=file_detail,
|
||||
document_model=args["doc_form"],
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args["info_list"]["data_source_type"] == "notion_import":
|
||||
notion_info_list = args["info_list"]["notion_info_list"]
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info["workspace_id"]
|
||||
credential_id = notion_info.get("credential_id")
|
||||
for page in notion_info["pages"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": credential_id,
|
||||
"notion_workspace_id": workspace_id,
|
||||
"notion_obj_id": page["page_id"],
|
||||
"notion_page_type": page["type"],
|
||||
@ -459,7 +445,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
website_info_list = args["info_list"]["website_info_list"]
|
||||
for url in website_info_list["urls"]:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": website_info_list["provider"],
|
||||
"job_id": website_info_list["job_id"],
|
||||
@ -567,7 +553,7 @@ class DatasetIndexingStatusApi(Resource):
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data, 200
|
||||
return data
|
||||
|
||||
|
||||
class DatasetApiKeyApi(Resource):
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from typing import Literal, cast
|
||||
@ -41,7 +40,6 @@ from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import (
|
||||
@ -53,12 +51,9 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
@ -357,6 +352,9 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
if knowledge_config.indexing_technique == "high_quality":
|
||||
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
|
||||
@ -428,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
|
||||
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
|
||||
)
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
@ -470,15 +468,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
data_process_rule_dict = data_process_rule.to_dict()
|
||||
info_list = []
|
||||
extract_settings = []
|
||||
for document in documents:
|
||||
if document.indexing_status in {"completed", "error"}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
data_source_info = document.data_source_info_dict
|
||||
# format document files info
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
info_list.append(file_id)
|
||||
# format document notion info
|
||||
elif (
|
||||
data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info
|
||||
):
|
||||
pages = []
|
||||
page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]}
|
||||
pages.append(page)
|
||||
notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages}
|
||||
info_list.append(notion_info)
|
||||
|
||||
if document.data_source_type == "upload_file":
|
||||
if not data_source_info:
|
||||
continue
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
@ -490,17 +500,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
raise NotFound("File not found.")
|
||||
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
|
||||
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
elif document.data_source_type == "notion_import":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.NOTION.value,
|
||||
datasource_type="notion_import",
|
||||
notion_info={
|
||||
"credential_id": data_source_info["credential_id"],
|
||||
"notion_workspace_id": data_source_info["notion_workspace_id"],
|
||||
"notion_obj_id": data_source_info["notion_page_id"],
|
||||
"notion_page_type": data_source_info["type"],
|
||||
@ -510,10 +517,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == "website_crawl":
|
||||
if not data_source_info:
|
||||
continue
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type=DatasourceType.WEBSITE.value,
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": data_source_info["provider"],
|
||||
"job_id": data_source_info["job_id"],
|
||||
@ -656,7 +661,7 @@ class DocumentApi(DocumentResource):
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@ -961,7 +966,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise DocumentAlreadyFinishedError()
|
||||
retry_documents.append(document)
|
||||
except Exception:
|
||||
logger.exception("Failed to retry document, document id: %s", document_id)
|
||||
logging.exception("Failed to retry document, document id: %s", document_id)
|
||||
continue
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
@ -1019,41 +1024,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
"datasource_info": None,
|
||||
"datasource_type": None,
|
||||
"input_data": None,
|
||||
"datasource_node_id": None,
|
||||
}, 200
|
||||
return {
|
||||
"datasource_info": json.loads(log.datasource_info),
|
||||
"datasource_type": log.datasource_type,
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
@ -1075,6 +1045,3 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
||||
@ -71,9 +71,3 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
||||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
||||
|
||||
@ -23,8 +23,6 @@ from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
@ -83,5 +81,5 @@ class DatasetsHitTestingBase:
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Hit testing failed.")
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
|
||||
@ -1,362 +0,0 @@
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import make_response, redirect, request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
oauth_config = DatasourceProviderService().get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_config:
|
||||
raise ValueError(f"No OAuth Client Config for {provider_id}")
|
||||
|
||||
context_id = OAuthProxyService.create_proxy_context(
|
||||
user_id=current_user.id,
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
authorization_url_response = oauth_handler.get_authorization_url(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_id=plugin_id,
|
||||
provider=provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_config,
|
||||
)
|
||||
response = make_response(jsonable_encoder(authorization_url_response))
|
||||
response.set_cookie(
|
||||
"context_id",
|
||||
context_id,
|
||||
httponly=True,
|
||||
samesite="Lax",
|
||||
max_age=OAuthProxyService.__MAX_AGE__,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class DatasourceOAuthCallback(Resource):
|
||||
@setup_required
|
||||
def get(self, provider_id: str):
|
||||
context_id = request.cookies.get("context_id") or request.args.get("context_id")
|
||||
if not context_id:
|
||||
raise Forbidden("context_id not found")
|
||||
|
||||
context = OAuthProxyService.use_proxy_context(context_id)
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
oauth_client_params = datasource_provider_service.get_oauth_client(
|
||||
tenant_id=tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
if not oauth_client_params:
|
||||
raise NotFound()
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_response = oauth_handler.get_credentials(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
redirect_uri=redirect_uri,
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
avatar_url=oauth_response.metadata.get("avatar_url") or None,
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
|
||||
)
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_id: str):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOAuthCallback,
|
||||
"/oauth/plugin/<path:provider_id>/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource/<path:provider_id>",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDeleteApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/delete",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthListApi,
|
||||
"/auth/plugin/datasource/list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceHardCodeAuthListApi,
|
||||
"/auth/plugin/datasource/default-list",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthOauthCustomClient,
|
||||
"/auth/plugin/datasource/<path:provider_id>/custom-client",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthDefaultApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/default",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceUpdateProviderNameApi,
|
||||
"/auth/plugin/datasource/<path:provider_id>/update-name",
|
||||
)
|
||||
@ -1,57 +0,0 @@
|
||||
from flask_restx import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
parser.add_argument("credential_id", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
credential_id=args.get("credential_id"),
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
@ -1,164 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
knowledge_pipeline_publish_enabled,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
@ -1,114 +0,0 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restx import Resource, marshal, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=args["yaml_content"],
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
|
||||
name="",
|
||||
description="",
|
||||
icon_info=IconInfo(
|
||||
icon="📙",
|
||||
icon_background="#FFF4ED",
|
||||
icon_type="emoji",
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
@ -1,389 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.workflow_draft_variable import (
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
)
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
if not isinstance(current_user, Account) or not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class RagPipelineVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(pipeline.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(pipeline.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class RagPipelineSystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, pipeline: Pipeline):
|
||||
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class RagPipelineEnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, pipeline: Pipeline):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
RagPipelineVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
|
||||
)
|
||||
@ -1,147 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
dataset_name=args.get("name"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
62
api/controllers/console/datasets/upload_file.py
Normal file
62
api/controllers/console/datasets/upload_file.py
Normal file
@ -0,0 +1,62 @@
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
|
||||
class UploadFileApi(Resource):
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""Get upload file."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check upload file
|
||||
if document.data_source_type != "upload_file":
|
||||
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
|
||||
data_source_info = document.data_source_info_dict
|
||||
if data_source_info and "upload_file_id" in data_source_info:
|
||||
file_id = data_source_info["upload_file_id"]
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
else:
|
||||
raise ValueError("Upload file id not found in document data source info.")
|
||||
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"url": url,
|
||||
"download_url": f"{url}&as_attachment=true",
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.timestamp(),
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
|
||||
@ -1,47 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def get_rag_pipeline(
|
||||
view: Optional[Callable] = None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user is not an account")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@ -26,8 +26,6 @@ from services.errors.audio import (
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatAudioApi(InstalledAppResource):
|
||||
def post(self, installed_app):
|
||||
@ -40,7 +38,7 @@ class ChatAudioApi(InstalledAppResource):
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -61,7 +59,7 @@ class ChatAudioApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -85,7 +83,7 @@ class ChatTextApi(InstalledAppResource):
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except NoAudioUploadedServiceError:
|
||||
raise NoAudioUploadedError()
|
||||
@ -106,5 +104,5 @@ class ChatTextApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
@ -32,8 +32,6 @@ from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# define completion api for user
|
||||
class CompletionApi(InstalledAppResource):
|
||||
@ -67,7 +65,7 @@ class CompletionApi(InstalledAppResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -80,7 +78,7 @@ class CompletionApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -127,7 +125,7 @@ class ChatApi(InstalledAppResource):
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
logging.exception("App model config broken.")
|
||||
raise AppUnavailableError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -142,7 +140,7 @@ class ChatApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
|
||||
@ -61,6 +61,7 @@ class ConversationApi(InstalledAppResource):
|
||||
ConversationService.delete(app_model, conversation_id, current_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
WebConversationService.unpin(app_model, conversation_id, current_user)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ from services.errors.message import (
|
||||
)
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListApi(InstalledAppResource):
|
||||
@marshal_with(message_infinite_scroll_pagination_fields)
|
||||
@ -128,7 +126,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -160,7 +158,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
return {"data": questions}
|
||||
|
||||
@ -43,8 +43,6 @@ class ExploreAppMetaApi(InstalledAppResource):
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Get app meta"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise ValueError("App not found")
|
||||
return AppService().get_app_meta(app_model)
|
||||
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from libs import helper
|
||||
from libs.login import current_user
|
||||
from models.model import AppMode, InstalledApp
|
||||
@ -36,8 +35,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
Run workflow
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
@ -46,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
@ -66,7 +63,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
@ -76,18 +73,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
|
||||
Stop workflow task
|
||||
"""
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
assert current_user is not None
|
||||
|
||||
# Stop using both mechanisms for backward compatibility
|
||||
# Legacy stop flag mechanism (without user check)
|
||||
AppQueueManager.set_stop_flag_no_user_check(task_id)
|
||||
|
||||
# New graph engine command channel mechanism
|
||||
GraphEngineManager.send_stop_command(task_id)
|
||||
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, Optional, ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource
|
||||
@ -15,15 +13,19 @@ from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
def installed_app_required(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
if not kwargs.get("installed_app_id"):
|
||||
raise ValueError("missing installed_app_id in path parameters")
|
||||
|
||||
installed_app_id = kwargs.get("installed_app_id")
|
||||
installed_app_id = str(installed_app_id)
|
||||
|
||||
del kwargs["installed_app_id"]
|
||||
|
||||
installed_app = (
|
||||
db.session.query(InstalledApp)
|
||||
.where(
|
||||
@ -50,10 +52,10 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P],
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
def user_allowed_to_access_app(view=None):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(installed_app: InstalledApp, *args, **kwargs):
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import FileService
|
||||
@ -69,7 +68,7 @@ class FileApi(Resource):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -90,7 +89,7 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id)
|
||||
text = FileService.get_file_preview(file_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -14,7 +14,6 @@ from controllers.common.errors import (
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
@ -62,7 +61,7 @@ class RemoteFileUploadApi(Resource):
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
upload_file = FileService.upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.schemas.schema_manager import SchemaManager
|
||||
from libs.login import login_required
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpecSchemaDefinitionsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""
|
||||
Get system JSON Schema definitions specification
|
||||
Used for frontend component type mapping
|
||||
"""
|
||||
try:
|
||||
schema_manager = SchemaManager()
|
||||
schema_definitions = schema_manager.get_all_schema_definitions()
|
||||
return schema_definitions, 200
|
||||
except Exception:
|
||||
logger.exception("Failed to get schema definitions from local registry")
|
||||
# Return empty array as fallback
|
||||
return [], 200
|
||||
|
||||
|
||||
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")
|
||||
@ -9,8 +9,6 @@ from configs import dify_config
|
||||
|
||||
from . import api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VersionApi(Resource):
|
||||
def get(self):
|
||||
@ -36,7 +34,7 @@ class VersionApi(Resource):
|
||||
try:
|
||||
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
|
||||
except Exception as error:
|
||||
logger.warning("Check update version error: %s.", str(error))
|
||||
logging.warning("Check update version error: %s.", str(error))
|
||||
result["version"] = args.get("current_version")
|
||||
return result
|
||||
|
||||
@ -57,7 +55,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
|
||||
# Compare versions
|
||||
return latest > current
|
||||
except version.InvalidVersion:
|
||||
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||
logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
@ -9,17 +7,14 @@ from werkzeug.exceptions import Forbidden
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args, **kwargs):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import Account, TenantAccountRole
|
||||
from models.account import TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
@ -15,12 +15,10 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
@ -66,12 +64,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
assert tenant_id is not None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource):
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("emails", type=list, required=True, location="json")
|
||||
parser.add_argument("emails", type=str, required=True, location="json", action="append")
|
||||
parser.add_argument("role", type=str, required=True, default="admin", location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required, setup_req
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.helper import StrLen, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
from services.model_provider_service import ModelProviderService
|
||||
@ -46,109 +45,12 @@ class ModelProviderCredentialApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider: str):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
credentials = model_provider_service.get_provider_credential(
|
||||
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
|
||||
)
|
||||
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
|
||||
|
||||
return {"credentials": credentials}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.create_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.update_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credentials=args["credentials"],
|
||||
credential_id=args["credential_id"],
|
||||
credential_name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderCredentialSwitchApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
service = ModelProviderService()
|
||||
service.switch_active_provider_credential(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
credential_id=args["credential_id"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ModelProviderValidateApi(Resource):
|
||||
@setup_required
|
||||
@ -167,7 +69,7 @@ class ModelProviderValidateApi(Resource):
|
||||
error = ""
|
||||
|
||||
try:
|
||||
model_provider_service.validate_provider_credentials(
|
||||
model_provider_service.provider_credentials_validate(
|
||||
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
@ -182,6 +84,42 @@ class ModelProviderValidateApi(Resource):
|
||||
return response
|
||||
|
||||
|
||||
class ModelProviderApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
try:
|
||||
model_provider_service.save_provider_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider: str):
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class ModelProviderIconApi(Resource):
|
||||
"""
|
||||
Get model provider icon
|
||||
@ -249,10 +187,8 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
|
||||
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
|
||||
|
||||
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
|
||||
api.add_resource(
|
||||
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
|
||||
)
|
||||
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
|
||||
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
|
||||
|
||||
api.add_resource(
|
||||
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user