Compare commits

..

83 Commits

Author SHA1 Message Date
bdfb31f332 chore(api): ignore google.cloud.storage in MyPy 2025-08-26 12:45:04 +08:00
f762c969b3 chore(api): resolve MyPy issues 2025-08-26 12:44:25 +08:00
6b87b0fa42 Merge remote-tracking branch 'upstream/feat/support-bool-variable-fe' into feat/boolean-type 2025-08-26 10:07:32 +08:00
4386c9f4f1 main 2025-08-25 10:35:52 +08:00
62955f4d11 chore: hide json entrance 2025-08-25 10:26:11 +08:00
e38ba8403d feat: workflow support input json 2025-08-22 14:54:34 +08:00
a1f278bdc8 fix: one step run form 2025-08-22 14:43:23 +08:00
65c0f8d657 feat: webapp support json field 2025-08-22 14:36:12 +08:00
c3e4d5a059 feat: support choose and save deep obj 2025-08-21 15:31:34 +08:00
30775109fd feat: can edit and choose obj 2025-08-21 14:47:29 +08:00
76659843d7 fix: item value type 2025-08-20 16:32:29 +08:00
42258bf451 chore: node tracing node value 2025-08-20 16:24:53 +08:00
399d510e1b fix(api): handle array[boolean] in LoopNode
The frontend save `array[boolean]` as a JSON array with string
element (e.g. `["true", "false", "true"]`), which is not handled
by LoopNode correctly.

This commit fix the issue by converting each element into Python
bool type before building segment.
2025-08-20 15:33:19 +08:00
8462c79179 fix: can not use converstation vars in var aggregator node 2025-08-20 14:59:58 +08:00
f5f11f5a9e feat: basic input field edit 2025-08-20 14:23:33 +08:00
96e8321462 chore: add var can not has exist key 2025-08-20 14:06:46 +08:00
af6e5e8663 chore: edit filed exists 2025-08-20 14:00:38 +08:00
d4f976270d fix: chore: if boolean array bug 2025-08-19 16:16:10 +08:00
dc3cbd9a74 chore: handle duliplicate keys 2025-08-19 15:57:18 +08:00
056564c100 merge main 2025-08-19 15:48:01 +08:00
d5b99610fc chore: iteratino input support bool list 2025-08-07 10:18:45 +08:00
04d2b0775d fix: chatbot bool 2025-08-04 11:14:33 +08:00
6614e97b53 fix: basic chatbot bool 2025-08-04 10:59:20 +08:00
ca37e304cc fix: basic bool error 2025-08-04 10:53:35 +08:00
437729db97 feat: support var arrs 2025-08-01 18:36:18 +08:00
fbd0effa04 chore: handle bool value in vars 2025-08-01 18:26:58 +08:00
14c62850e0 fix: list filter operation filter boolean problem 2025-08-01 18:06:27 +08:00
c4824a1d4c fix: assigner set false can not public 2025-08-01 17:55:20 +08:00
d642e9a98e chore: fix can not publish 2025-08-01 17:39:25 +08:00
72c6195c10 fix: list operation boolean default 2025-08-01 17:32:09 +08:00
2bd096b454 fix: con vars set value 2025-08-01 17:24:53 +08:00
fa98505e09 fix(api): enhance compatibility with existing workflow
Fix issues flagged by unit tests.
2025-08-01 15:27:33 +08:00
65ae25e09b test(api): add test cases for ParameterConfig validation logic 2025-08-01 15:25:28 +08:00
4943d8bc93 test(api): Introduce test cases for ParameterExtractorNode._validate_result 2025-08-01 15:25:28 +08:00
168d82f9b0 feat(api): support boolean types in parameter extractor node 2025-08-01 15:25:28 +08:00
0e3ccb4dcc feat(api): update boolean type handling in ListOperatorNode 2025-08-01 13:11:24 +08:00
f7f6210a4c fix(api): Fix incorrect handling of array types in SegmentType.is_valid
Add comphensive tests for the behavior of the `is_valid` method.
2025-08-01 13:09:46 +08:00
ed8e06d1bf fix(api): Allow bool value in the FilterCondition of list operator 2025-07-31 12:24:18 +08:00
6287e3cd9e fix: boolean can picker var 2025-07-31 10:28:17 +08:00
94c33c6eed chore: boolean string to boolean 2025-07-30 17:56:59 +08:00
383695b820 chore(api): put import for type annotation inside TYPE_CHECKING in llm/nodes.py
this causes unnecessary dependencies and cycle imports.
2025-07-30 16:13:04 +08:00
3751702efc feat(api): Implement boolean types support for list operator node 2025-07-30 10:58:18 +08:00
111c34e50f feat(api): Implement boolean types support for Code Node 2025-07-30 10:56:51 +08:00
a2246be4f5 feat(api): Implement checkbox input support in Workflow and Chat App 2025-07-30 10:56:51 +08:00
80e08562be feat(api): Initial support for boolean / array[boolean] types 2025-07-30 10:56:51 +08:00
6d03a15e0f chore: start form boolean to checkbox 2025-07-30 10:53:12 +08:00
bd04ddd544 merge main 2025-07-30 10:26:31 +08:00
cb51360aa3 chore: remove arr bool 2025-07-29 11:34:16 +08:00
269d34304e fix : loop var default value 2025-07-28 16:17:52 +08:00
590c14c977 chore: conveersion pass boolean string boolean 2025-07-28 16:02:50 +08:00
5fbeb275b2 fix: loop not choose bool var value 2025-07-28 15:20:18 +08:00
ff9d051635 chore: fix if not value set value 2025-07-28 14:49:44 +08:00
1e18c828d9 chore: remove transform boolean to string in editor 2025-07-28 14:09:34 +08:00
5b895be412 feat: template support bools input 2025-07-28 13:46:55 +08:00
ea2a1b203c fix: list operator bool not show the bool input 2025-07-28 11:20:22 +08:00
7bfdfe73c2 chore: list filter bool value 2025-07-28 11:09:37 +08:00
171c5055f8 fix: loop arry bool inner item type value 2025-07-25 14:52:13 +08:00
e3e4369358 fix: webapp get boolean value 2025-07-25 14:05:14 +08:00
efa28453be fix: if value show 2025-07-24 17:54:22 +08:00
37de3b1b68 fix: single run bool type 2025-07-24 17:39:10 +08:00
75d4e519fc chore: loop false can run 2025-07-24 16:06:13 +08:00
cf63926e16 chore: assigner node 2025-07-24 15:59:44 +08:00
14eb5b7930 feat: list operator support bool 2025-07-24 15:15:35 +08:00
544ebde054 chore: if bool value problem 2025-07-24 15:05:09 +08:00
6012ad57ac feat: code support boolean 2025-07-24 14:35:41 +08:00
70f6d8b42a main 2025-07-24 14:10:29 +08:00
1d738f3fa6 feat: chat support bool 2025-07-09 16:48:17 +08:00
912d68a148 chore: competion webapp 2025-07-09 14:56:11 +08:00
8c8c250570 feat: completion support boolean 2025-07-09 14:28:36 +08:00
b5782fff8f feat: basic chat app support bool 2025-07-09 14:06:25 +08:00
f2dfb5363f feat: one step run support boolean 2025-07-09 11:23:21 +08:00
fc0dea5647 feat: workflow debug add bool 2025-07-09 10:53:52 +08:00
fcf0ae4c85 feat: extract parameter support bool 2025-07-08 16:41:58 +08:00
9ac629bcf5 feat: if node support bool 2025-07-08 16:30:54 +08:00
79bdb22f79 feat: loop termin support bool 2025-07-08 16:06:56 +08:00
b72d977871 feat: loop var support bool type 2025-07-08 15:23:24 +08:00
657d3d1dae fix: boolean problem 2025-07-08 14:34:18 +08:00
3ba23238e5 chore: boolarray in varibale 2025-07-04 17:47:44 +08:00
e1a7b59160 chore: missing files 2025-07-04 16:48:56 +08:00
18941beb34 feat: converstation support bool 2025-07-04 16:48:33 +08:00
6ff18d13e6 feat: struct output boolean 2025-07-04 16:12:47 +08:00
1094b3da23 feat: add var type 2025-07-04 16:02:28 +08:00
77aa35ff15 feat: add boolean type 2025-07-04 15:27:52 +08:00
2906 changed files with 31615 additions and 137150 deletions

View File

@ -1,19 +0,0 @@
{
"permissions": {
"allow": [],
"deny": []
},
"env": {
"__comment": "Environment variables for MCP servers. Override in .claude/settings.local.json with actual values.",
"GITHUB_PERSONAL_ACCESS_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
},
"enabledMcpjsonServers": [
"context7",
"sequential-thinking",
"github",
"fetch",
"playwright",
"ide"
],
"enableAllProjectMcpServers": true
}

View File

@ -1,6 +1,6 @@
#!/bin/bash
corepack enable
npm add -g pnpm@10.15.0
cd web && pnpm install
pipx install uv

34
.github/actions/setup-uv/action.yml vendored Normal file
View File

@ -0,0 +1,34 @@
name: Setup UV and Python
inputs:
python-version:
description: Python version to use and the UV installed with
required: true
default: '3.12'
uv-version:
description: UV version to set up
required: true
default: '0.8.9'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
default: ''
enable-cache:
required: true
default: true
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}
enable-cache: ${{ inputs.enable-cache }}
cache-dependency-glob: ${{ inputs.uv-lockfile }}

View File

@ -1,7 +1,13 @@
name: Run Pytest
on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/**
- docker/**
- .github/workflows/api-tests.yml
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
@ -27,11 +33,10 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: ./.github/actions/setup-uv
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
cache-dependency-glob: api/uv.lock
uv-lockfile: api/uv.lock
- name: Check UV lockfile
run: uv lock --project api --check
@ -42,7 +47,11 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Run ty check
run: |
cd api
uv add --dev ty
uv run ty check || true
- name: Run pyrefly check
run: |
cd api
@ -62,6 +71,15 @@ jobs:
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
- name: MyPy Cache
uses: actions/cache@v4
with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env

View File

@ -1,7 +1,9 @@
name: autofix.ci
on:
workflow_call:
pull_request:
branches: ["main"]
push:
branches: [ "main" ]
permissions:
contents: read
@ -13,9 +15,7 @@ jobs:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.12"
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
- run: |
cd api
uv sync --dev
@ -26,7 +26,6 @@ jobs:
- name: ast-grep
run: |
uvx --from ast-grep-cli sg --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all
uvx --from ast-grep-cli sg --pattern 'session.query($WHATEVER).filter($HERE)' --rewrite 'session.query($WHATEVER).where($HERE)' -l py --update-all
- name: mdformat
run: |
uvx mdformat .

View File

@ -8,8 +8,6 @@ on:
- "deploy/enterprise"
- "build/**"
- "release/e-*"
- "deploy/rag-dev"
- "feat/rag-2"
tags:
- "*"

View File

@ -1,7 +1,13 @@
name: DB Migration Test
on:
workflow_call:
pull_request:
branches:
- main
- plugins/beta
paths:
- api/migrations/**
- .github/workflows/db-migration-test.yml
concurrency:
group: db-migration-test-${{ github.ref }}
@ -19,20 +25,12 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: ./.github/actions/setup-uv
with:
enable-cache: true
python-version: "3.12"
cache-dependency-glob: api/uv.lock
uv-lockfile: api/uv.lock
- name: Install dependencies
run: uv sync --project api
- name: Ensure Offline migration are supported
run: |
# upgrade
uv run --directory api flask db upgrade 'base:head' --sql
# downgrade
uv run --directory api flask db downgrade 'head:base' --sql
- name: Prepare middleware env
run: |

View File

@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/dev"
types:
- completed
@ -12,13 +12,12 @@ jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.conclusion == 'success'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@ -1,78 +0,0 @@
name: Main CI Pipeline
on:
pull_request:
branches: ["main"]
push:
branches: ["main"]
permissions:
contents: write
pull-requests: write
checks: write
statuses: write
concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
# Check which paths were changed to determine which tests to run
check-changes:
name: Check Changed Files
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
web-changed: ${{ steps.changes.outputs.web }}
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: changes
with:
filters: |
api:
- 'api/**'
- 'docker/**'
- '.github/workflows/api-tests.yml'
web:
- 'web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'
- '.github/workflows/vdb-tests.yml'
- 'api/uv.lock'
- 'api/pyproject.toml'
migration:
- 'api/migrations/**'
- '.github/workflows/db-migration-test.yml'
# Run tests in parallel
api-tests:
name: API Tests
needs: check-changes
if: needs.check-changes.outputs.api-changed == 'true'
uses: ./.github/workflows/api-tests.yml
web-tests:
name: Web Tests
needs: check-changes
if: needs.check-changes.outputs.web-changed == 'true'
uses: ./.github/workflows/web-tests.yml
style-check:
name: Style Check
uses: ./.github/workflows/style.yml
vdb-tests:
name: VDB Tests
needs: check-changes
if: needs.check-changes.outputs.vdb-changed == 'true'
uses: ./.github/workflows/vdb-tests.yml
db-migration-test:
name: DB Migration Test
needs: check-changes
if: needs.check-changes.outputs.migration-changed == 'true'
uses: ./.github/workflows/db-migration-test.yml

View File

@ -1,7 +1,9 @@
name: Style check
on:
workflow_call:
pull_request:
branches:
- main
concurrency:
group: style-${{ github.head_ref || github.run_id }}
@ -12,6 +14,7 @@ permissions:
statuses: write
contents: read
jobs:
python-style:
name: Python Style
@ -33,28 +36,30 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v6
uses: ./.github/actions/setup-uv
with:
uv-lockfile: api/uv.lock
enable-cache: false
python-version: "3.12"
cache-dependency-glob: api/uv.lock
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
run: uv sync --project api --dev
- name: Run Basedpyright Checks
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: dev/basedpyright-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints
if: failure()
run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
web-style:
name: Web Style
runs-on: ubuntu-latest
@ -96,9 +101,7 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
pnpm run lint
pnpm run eslint
run: pnpm run lint
docker-compose-template:
name: Docker Compose Template

View File

@ -67,22 +67,12 @@ jobs:
working-directory: ./web
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Generate i18n type definitions
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run gen:i18n-types
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: Update i18n files and type definitions based on en-US changes
title: 'chore: translate i18n files and update type definitions'
body: |
This PR was automatically created to update i18n files and TypeScript type definitions based on changes in en-US locale.
**Changes included:**
- Updated translation files for all locales
- Regenerated TypeScript type definitions for type safety
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates

View File

@ -1,7 +1,15 @@
name: Run VDB Tests
on:
workflow_call:
pull_request:
branches:
- main
paths:
- api/core/rag/datasource/**
- docker/**
- .github/workflows/vdb-tests.yml
- api/uv.lock
- api/pyproject.toml
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}
@ -31,11 +39,10 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v6
uses: ./.github/actions/setup-uv
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
cache-dependency-glob: api/uv.lock
uv-lockfile: api/uv.lock
- name: Check UV lockfile
run: uv lock --project api --check

View File

@ -1,7 +1,11 @@
name: Web Tests
on:
workflow_call:
pull_request:
branches:
- main
paths:
- web/**
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}
@ -47,11 +51,6 @@ jobs:
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web

13
.gitignore vendored
View File

@ -123,12 +123,10 @@ venv.bak/
# mkdocs documentation
/site
# type checking
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
pyrightconfig.json
!api/pyrightconfig.json
# Pyre type checker
.pyre/
@ -197,8 +195,8 @@ sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
pyrightconfig.json
api/.vscode
web/.vscode
# vscode Code History Extension
.history
@ -216,13 +214,6 @@ mise.toml
# Next.js build output
.next/
# PWA generated files
web/public/sw.js
web/public/sw.js.map
web/public/workbox-*.js
web/public/workbox-*.js.map
web/public/fallback-*.js
# AI Assistant
.roo/
api/.env.backup

View File

@ -1,34 +0,0 @@
{
"mcpServers": {
"context7": {
"type": "http",
"url": "https://mcp.context7.com/mcp"
},
"sequential-thinking": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-sequential-thinking"],
"env": {}
},
"github": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-github"],
"env": {
"GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_PERSONAL_ACCESS_TOKEN}"
}
},
"fetch": {
"type": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"env": {}
},
"playwright": {
"type": "stdio",
"command": "npx",
"args": ["-y", "@playwright/mcp@latest"],
"env": {}
}
}
}

View File

@ -1 +0,0 @@
CLAUDE.md

View File

@ -32,7 +32,7 @@ uv run --project api pytest tests/integration_tests/ # Integration tests
./dev/reformat # Run all formatters and linters
uv run --project api ruff check --fix ./ # Fix linting issues
uv run --project api ruff format ./ # Format code
uv run --directory api basedpyright # Type checking
uv run --project api mypy . # Type checking
```
### Frontend (Web)
@ -86,4 +86,3 @@ pnpm test # Run Jest tests
## Project-Specific Conventions
- All async tasks use Celery with Redis as broker
- **Internationalization**: Frontend supports multiple languages with English (`web/i18n/en-US/`) as the source. All user-facing text must use i18n keys, no hardcoded strings. Edit corresponding module files in `en-US/` directory for translations.

View File

@ -4,48 +4,6 @@ WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
# Backend Development Environment Setup
.PHONY: dev-setup prepare-docker prepare-web prepare-api
# Default dev setup target
dev-setup: prepare-docker prepare-web prepare-api
@echo "✅ Backend development environment setup complete!"
# Step 1: Prepare Docker middleware
prepare-docker:
@echo "🐳 Setting up Docker middleware..."
@cp -n docker/middleware.env.example docker/middleware.env 2>/dev/null || echo "Docker middleware.env already exists"
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev up -d
@echo "✅ Docker middleware started"
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@cd web && pnpm build
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
prepare-api:
@echo "🔧 Setting up API environment..."
@cp -n api/.env.example api/.env 2>/dev/null || echo "API .env already exists"
@cd api && uv sync --dev
@cd api && uv run flask db upgrade
@echo "✅ API environment prepared (not started)"
# Clean dev environment
dev-clean:
@echo "⚠️ Stopping Docker containers..."
@cd docker && docker compose -f docker-compose.middleware.yaml --env-file middleware.env -p dify-middlewares-dev down
@echo "🗑️ Removing volumes..."
@rm -rf docker/volumes/db
@rm -rf docker/volumes/redis
@rm -rf docker/volumes/plugin_daemon
@rm -rf docker/volumes/weaviate
@rm -rf api/storage
@echo "✅ Cleanup complete"
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
@ -81,21 +39,5 @@ build-push-web: build-web push-web
build-push-all: build-all push-all
@echo "All Docker images have been built and pushed."
# Help target
help:
@echo "Development Setup Targets:"
@echo " make dev-setup - Run all setup steps for backend dev environment"
@echo " make prepare-docker - Set up Docker middleware"
@echo " make prepare-web - Set up web environment"
@echo " make prepare-api - Set up API environment"
@echo " make dev-clean - Stop Docker middleware containers"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"
@echo " make build-api - Build API Docker image"
@echo " make build-all - Build all Docker images"
@echo " make push-all - Push all Docker images"
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all

View File

@ -180,7 +180,7 @@ docker compose up -d
## Contributing
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。
对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。
> 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。

View File

@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline
## Contributing
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
> Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c).

View File

@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @
## Contribuir
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md).
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias.
> Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c).

View File

@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart
## Contribuer
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md).
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences.
> Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c).

View File

@ -169,7 +169,7 @@ docker compose up -d
## 貢献
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。
> Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。

View File

@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요.
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다.
> 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요.

View File

@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by
## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md).
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências.
> Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c).

View File

@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
## Katkıda Bulunma
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz.
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün.
> Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın.

View File

@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API因此您可以輕鬆地將 Dify
## 貢獻
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。

View File

@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De
## Đóng góp
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi.
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị.
> Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi.

View File

@ -75,7 +75,6 @@ DB_PASSWORD=difyai123456
DB_HOST=localhost
DB_PORT=5432
DB_DATABASE=dify
SQLALCHEMY_POOL_PRE_PING=true
# Storage configuration
# use for store upload files, private keys...
@ -461,16 +460,6 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
# Seconds of idle time before scaling down workers (default: 5.0)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
@ -575,7 +564,3 @@ QUEUE_MONITOR_THRESHOLD=200
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30
# Swagger UI configuration
SWAGGER_UI_ENABLED=true
SWAGGER_UI_PATH=/swagger-ui.html

View File

@ -1,112 +0,0 @@
[importlinter]
root_packages =
core
configs
controllers
models
tasks
services
[importlinter:contract:workflow]
name = Workflow
type=layers
layers =
graph_engine
graph_events
graph
nodes
node_events
entities
containers =
core.workflow
ignore_imports =
core.workflow.nodes.base.node -> core.workflow.graph_events
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels
core.workflow.nodes.loop.loop_node -> core.workflow.graph_events
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
core.workflow.nodes.node_factory -> core.workflow.graph
[importlinter:contract:rsc]
name = RSC
type = layers
layers =
graph_engine
response_coordinator
containers =
core.workflow.graph_engine
[importlinter:contract:worker]
name = Worker
type = layers
layers =
graph_engine
worker
containers =
core.workflow.graph_engine
[importlinter:contract:graph-engine-architecture]
name = Graph Engine Architecture
type = layers
layers =
graph_engine
orchestration
command_processing
event_management
error_handling
graph_traversal
state_management
worker_management
domain
containers =
core.workflow.graph_engine
[importlinter:contract:domain-isolation]
name = Domain Model Isolation
type = forbidden
source_modules =
core.workflow.graph_engine.domain
forbidden_modules =
core.workflow.graph_engine.worker_management
core.workflow.graph_engine.command_channels
core.workflow.graph_engine.layers
core.workflow.graph_engine.protocols
[importlinter:contract:worker-management]
name = Worker Management
type = forbidden
source_modules =
core.workflow.graph_engine.worker_management
forbidden_modules =
core.workflow.graph_engine.orchestration
core.workflow.graph_engine.command_processing
core.workflow.graph_engine.event_management
[importlinter:contract:error-handling-strategies]
name = Error Handling Strategies
type = independence
modules =
core.workflow.graph_engine.error_handling.abort_strategy
core.workflow.graph_engine.error_handling.retry_strategy
core.workflow.graph_engine.error_handling.fail_branch_strategy
core.workflow.graph_engine.error_handling.default_value_strategy
[importlinter:contract:graph-traversal-components]
name = Graph Traversal Components
type = layers
layers =
edge_processor
skip_propagator
containers =
core.workflow.graph_engine.graph_traversal
[importlinter:contract:command-channels]
name = Command Channels Independence
type = independence
modules =
core.workflow.graph_engine.command_channels.in_memory_channel
core.workflow.graph_engine.command_channels.redis_channel

View File

@ -43,7 +43,6 @@ select = [
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage
"G001", # don't use str format to logging messages
"G003", # don't use + in logging messages
"G004", # don't use f-strings to format logging messages
]

View File

@ -97,16 +97,8 @@ uv run celery -A app.celery beat
uv sync --dev
```
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md)
1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash
uv run pytest # Run all tests
uv run pytest tests/unit_tests/ # Unit tests only
uv run pytest tests/integration_tests/ # Integration tests
# Code quality
../dev/reformat # Run all formatters and linters
uv run ruff check --fix ./ # Fix linting issues
uv run ruff format ./ # Format code
uv run basedpyright . # Type checking
uv run -P api bash dev/pytest/pytest_all_tests.sh
```

View File

@ -1,3 +1,4 @@
import os
import sys
@ -16,20 +17,20 @@ else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app

View File

@ -5,8 +5,6 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp
logger = logging.getLogger(__name__)
# ----------------------------
# Application Factory Function
@ -25,9 +23,6 @@ def create_flask_app_with_configs() -> DifyApp:
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
return dify_app
@ -37,7 +32,7 @@ def create_app() -> DifyApp:
initialize_extensions(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return app
@ -98,14 +93,14 @@ def initialize_extensions(app: DifyApp):
is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
if not is_enabled:
if dify_config.DEBUG:
logger.info("Skipped %s", short_name)
logging.info("Skipped %s", short_name)
continue
start_time = time.perf_counter()
ext.init_app(app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app():

View File

@ -1,22 +0,0 @@
import logging
import psycogreen.gevent as pscycogreen_gevent # type: ignore
from grpc.experimental import gevent as grpc_gevent # type: ignore
_logger = logging.getLogger(__name__)
def _log(message: str):
print(message, flush=True)
# grpc gevent
grpc_gevent.init_gevent()
_log("gRPC patched with gevent.")
pscycogreen_gevent.patch_psycopg()
_log("psycopg2 patched with gevent.")
from app import app, celery
__all__ = ["app", "celery"]

11
api/child_class.py Normal file
View File

@ -0,0 +1,11 @@
from tests.integration_tests.utils.parent_class import ParentClass
class ChildClass(ParentClass):
"""Test child class for module import helper tests"""
def __init__(self, name):
super().__init__(name)
def get_name(self):
return f"Child: {self.name}"

View File

@ -13,13 +13,11 @@ from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from constants.languages import languages
from core.helper import encrypter
from core.plugin.impl.plugin import PluginInstaller
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
@ -32,20 +30,14 @@ from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@ -577,7 +569,7 @@ def old_metadata_migration():
for document in documents:
if document.doc_metadata:
doc_metadata = document.doc_metadata
for key in doc_metadata:
for key, value in doc_metadata.items():
for field in BuiltInField:
if field.value == key:
break
@ -693,7 +685,7 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception:
logger.exception("Failed to execute database migration")
logging.exception("Failed to execute database migration")
finally:
lock.release()
else:
@ -741,7 +733,7 @@ where sites.id is null limit 1000"""
except Exception:
failed_app_ids.append(app_id)
click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
continue
if not processed_count:
@ -1239,17 +1231,15 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
Count orphaned draft variables by app, including associated file counts.
Count orphaned draft variables by app.
Returns:
Dictionary with statistics about orphaned variables and files
Dictionary with statistics about orphaned variables
"""
# Count orphaned variables by app
variables_query = """
query = """
SELECT
wdv.app_id,
COUNT(*) as variable_count,
COUNT(wdv.file_id) as file_count
COUNT(*) as variable_count
FROM workflow_draft_variables AS wdv
WHERE NOT EXISTS(
SELECT 1 FROM apps WHERE apps.id = wdv.app_id
@ -1259,21 +1249,14 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query))
orphaned_by_app = {}
total_files = 0
result = conn.execute(sa.text(query))
orphaned_by_app = {row[0]: row[1] for row in result}
for row in result:
app_id, variable_count, file_count = row
orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count}
total_files += file_count
total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values())
total_orphaned = sum(orphaned_by_app.values())
app_count = len(orphaned_by_app)
return {
"total_orphaned_variables": total_orphaned,
"total_orphaned_files": total_files,
"orphaned_app_count": app_count,
"orphaned_by_app": orphaned_by_app,
}
@ -1302,7 +1285,6 @@ def cleanup_orphaned_draft_variables(
stats = _count_orphaned_draft_variables()
logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"])
logger.info("Found %s associated offload files", stats["total_orphaned_files"])
logger.info("Across %s non-existent apps", stats["orphaned_app_count"])
if stats["total_orphaned_variables"] == 0:
@ -1311,10 +1293,10 @@ def cleanup_orphaned_draft_variables(
if dry_run:
logger.info("DRY RUN: Would delete the following:")
for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[
for app_id, count in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1], reverse=True)[
:10
]: # Show top 10
logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"])
logger.info(" App %s: %s variables", app_id, count)
if len(stats["orphaned_by_app"]) > 10:
logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10)
return
@ -1323,8 +1305,7 @@ def cleanup_orphaned_draft_variables(
if not force:
click.confirm(
f"Are you sure you want to delete {stats['total_orphaned_variables']} "
f"orphaned draft variables and {stats['total_orphaned_files']} associated files "
f"from {stats['orphaned_app_count']} apps?",
f"orphaned draft variables from {stats['orphaned_app_count']} apps?",
abort=True,
)
@ -1357,231 +1338,3 @@ def cleanup_orphaned_draft_variables(
continue
logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps)
@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_datasource_oauth_client(provider, client_params):
"""
Setup datasource oauth client
"""
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow"))
deleted_count = (
db.session.query(DatasourceOauthParamConfig)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow"))
oauth_client = DatasourceOauthParamConfig(
provider=provider_name,
plugin_id=plugin_id,
system_credentials=client_params_dict,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"provider: {provider_name}", fg="green"))
click.echo(click.style(f"plugin_id: {plugin_id}", fg="green"))
click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
@click.command("transform-datasource-credentials", help="Transform datasource credentials.")
def transform_datasource_credentials():
"""
Transform datasource credentials
"""
try:
installer_manager = PluginInstaller()
plugin_migration = PluginMigration()
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
# deal notion credentials
deal_notion_count = 0
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for credential in notion_credentials:
tenant_id = credential.tenant_id
if tenant_id not in notion_credentials_tenant_mapping:
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if notion_plugin_id not in installed_plugins_ids:
if notion_plugin_unique_identifier:
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential oauth params
access_token = credential.access_token
# notion info
notion_info = credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
new_credentials = {
"integration_secret": encrypter.encrypt_token(tenant_id, access_token),
"workspace_id": workspace_id,
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
}
datasource_provider = DatasourceProvider(
provider="notion_datasource",
tenant_id=tenant_id,
plugin_id=notion_plugin_id,
auth_type=oauth_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url=workspace_icon or "default",
is_default=False,
)
db.session.add(datasource_provider)
deal_notion_count += 1
db.session.commit()
# deal firecrawl credentials
deal_firecrawl_count = 0
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in firecrawl_credentials:
tenant_id = credential.tenant_id
if tenant_id not in firecrawl_credentials_tenant_mapping:
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if firecrawl_plugin_id not in installed_plugins_ids:
if firecrawl_plugin_unique_identifier:
# install firecrawl plugin
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
"firecrawl_api_key": api_key,
"base_url": base_url,
}
datasource_provider = DatasourceProvider(
provider="firecrawl",
tenant_id=tenant_id,
plugin_id=firecrawl_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_firecrawl_count += 1
db.session.commit()
# deal jina credentials
deal_jina_count = 0
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in jina_credentials:
tenant_id = credential.tenant_id
if tenant_id not in jina_credentials_tenant_mapping:
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
if jina_plugin_id not in installed_plugins_ids:
if jina_plugin_unique_identifier:
# install jina plugin
print(jina_plugin_unique_identifier)
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
}
datasource_provider = DatasourceProvider(
provider="jina",
tenant_id=tenant_id,
plugin_id=jina_plugin_id,
auth_type=api_key_credential_type.value,
encrypted_credentials=new_credentials,
name=f"Auth {auth_count}",
avatar_url="default",
is_default=False,
)
db.session.add(datasource_provider)
deal_jina_count += 1
db.session.commit()
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
click.echo(
click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
)
click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.")
@click.option(
"--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl"
)
@click.option(
"--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl"
)
@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100)
def install_rag_pipeline_plugins(input_file, output_file, workers):
"""
Install rag pipeline plugins
"""
click.echo(click.style("Installing rag pipeline plugins", fg="yellow"))
plugin_migration = PluginMigration()
plugin_migration.install_rag_pipeline_plugins(
input_file,
output_file,
workers,
)
click.echo(click.style("Installing rag pipeline plugins successfully", fg="green"))

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Annotated, Literal, Optional
from pydantic import (
AliasChoices,
@ -499,22 +499,6 @@ class UpdateConfig(BaseSettings):
)
class WorkflowVariableTruncationConfig(BaseSettings):
WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE: PositiveInt = Field(
# 100KB
1024_000,
description="Maximum size for variable to trigger final truncation.",
)
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
50000,
description="maximum length for string to trigger tuncation, measure in number of characters",
)
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
100,
description="maximum length for array to trigger truncation.",
)
class WorkflowConfig(BaseSettings):
"""
Configuration for workflow execution
@ -545,28 +529,6 @@ class WorkflowConfig(BaseSettings):
default=200 * 1024,
)
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
default=1,
)
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
description="Maximum number of workers per GraphEngine instance",
default=10,
)
GRAPH_ENGINE_SCALE_UP_THRESHOLD: PositiveInt = Field(
description="Queue depth threshold that triggers worker scale up",
default=3,
)
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME: float = Field(
description="Seconds of idle time before scaling down workers",
default=5.0,
ge=0.1,
)
class WorkflowNodeExecutionConfig(BaseSettings):
"""
@ -1014,18 +976,6 @@ class WorkflowLogConfig(BaseSettings):
)
class SwaggerUIConfig(BaseSettings):
SWAGGER_UI_ENABLED: bool = Field(
description="Whether to enable Swagger UI in api module",
default=True,
)
SWAGGER_UI_PATH: str = Field(
description="Swagger UI page path in api module",
default="/swagger-ui.html",
)
class FeatureConfig(
# place the configs in alphabet order
AppExecutionConfig,
@ -1057,12 +1007,10 @@ class FeatureConfig(
WorkspaceConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
WorkflowLogConfig,
WorkflowVariableTruncationConfig,
):
pass

View File

@ -222,28 +222,11 @@ class HostedFetchAppTemplateConfig(BaseSettings):
)
class HostedFetchPipelineTemplateConfig(BaseSettings):
"""
Configuration for fetching pipeline templates
"""
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="remote",
)
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote pipeline templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
# place the configs in alphabet order
HostedAnthropicConfig,
HostedAzureOpenAiConfig,
HostedFetchAppTemplateConfig,
HostedFetchPipelineTemplateConfig,
HostedMinmaxConfig,
HostedOpenAiConfig,
HostedSparkConfig,

View File

@ -215,7 +215,6 @@ class DatabaseConfig(BaseSettings):
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
"pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
"pool_reset_on_return": None,
}
@ -300,7 +299,8 @@ class DatasetQueueMonitorConfig(BaseSettings):
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
CeleryConfig,
DatabaseConfig,
KeywordStoreConfig,
RedisConfig,
# configs of storage and storage providers

View File

@ -1,10 +1,9 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class ClickzettaConfig(BaseSettings):
class ClickzettaConfig(BaseModel):
"""
Clickzetta Lakehouse vector database configuration
"""

View File

@ -1,8 +1,7 @@
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic import BaseModel, Field
class MatrixoneConfig(BaseSettings):
class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")

View File

@ -1,6 +1,6 @@
from pydantic import Field
from configs.packaging.pyproject import PyProjectTomlConfig
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
class PackagingInfo(PyProjectTomlConfig):

View File

@ -4,9 +4,8 @@ import logging
import os
import threading
import time
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from .python_3x import http_request, makedirs_wrapper
from .utils import (
@ -26,13 +25,13 @@ logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url: str,
app_id: str,
cluster: str = "default",
secret: str = "",
start_hot_update: bool = True,
change_listener: Callable[[str, str, str, Any], None] | None = None,
_notification_map: dict[str, int] | None = None,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
):
# Core routing parameters
self.config_url = config_url
@ -48,17 +47,17 @@ class ApolloClient:
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache: dict[str, dict[str, Any]] = {}
self._no_key: dict[str, str] = {}
self._hash: dict[str, str] = {}
self._cache = {}
self._no_key = {}
self._hash = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread: threading.Thread | None = None
self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key: str | None = None
self.last_release_key = None
# Private startup method
self._path_checker()
if start_hot_update:
@ -69,7 +68,7 @@ class ApolloClient:
heartbeat.daemon = True
heartbeat.start()
def get_json_from_net(self, namespace: str = "application") -> dict[str, Any] | None:
def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
@ -89,7 +88,7 @@ class ApolloClient:
logger.exception("an error occurred in get_json_from_net")
return None
def get_value(self, key: str, default_val: Any = None, namespace: str = "application") -> Any:
def get_value(self, key, default_val=None, namespace="application"):
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
@ -105,8 +104,7 @@ class ApolloClient:
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
if namespace_data is not None:
self._update_cache_and_file(namespace_data, namespace)
self._update_cache_and_file(namespace_data, namespace)
return val
# read the file configuration
@ -128,23 +126,23 @@ class ApolloClient:
# to ensure the real-time correctness of the function call.
# If the user does not have the same default val twice
# and the default val is used here, there may be a problem.
def _set_local_cache_none(self, namespace: str, key: str) -> None:
def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key
def _start_hot_update(self) -> None:
def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener)
# When the asynchronous thread is started, the daemon thread will automatically exit
# when the main thread is launched.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()
def stop(self) -> None:
def stop(self):
self._stopping = True
logger.info("Stopping listener...")
# Call the set callback function, and if it is abnormal, try it out
def _call_listener(self, namespace: str, old_kv: dict[str, Any] | None, new_kv: dict[str, Any] | None) -> None:
def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None:
return
if old_kv is None:
@ -170,12 +168,12 @@ class ApolloClient:
except BaseException as e:
logger.warning(str(e))
def _path_checker(self) -> None:
def _path_checker(self):
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache
def _update_cache_and_file(self, namespace_data: dict[str, Any], namespace: str = "application") -> None:
def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
@ -189,7 +187,7 @@ class ApolloClient:
self._hash[namespace] = new_hash
# get the configuration from the local file
def _get_local_cache(self, namespace: str = "application") -> dict[str, Any]:
def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
@ -197,8 +195,8 @@ class ApolloClient:
return result
return {}
def _long_poll(self) -> None:
notifications: list[dict[str, Any]] = []
def _long_poll(self):
notifications = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
@ -238,7 +236,7 @@ class ApolloClient:
except Exception as e:
logger.warning(str(e))
def _get_net_and_set_local(self, namespace: str, n_id: int, call_change: bool = False) -> None:
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
@ -250,7 +248,7 @@ class ApolloClient:
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)
def _listener(self) -> None:
def _listener(self):
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
@ -268,13 +266,13 @@ class ApolloClient:
headers["Timestamp"] = time_unix_now
return headers
def _heart_beat(self) -> None:
def _heart_beat(self):
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace: str) -> None:
def _do_heart_beat(self, namespace):
url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
@ -294,7 +292,7 @@ class ApolloClient:
logger.exception("an error occurred in _do_heart_beat")
return None
def get_all_dicts(self, namespace: str) -> dict[str, Any] | None:
def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)

View File

@ -2,8 +2,6 @@ import logging
import os
import ssl
import urllib.request
from collections.abc import Mapping
from typing import Any
from urllib import parse
from urllib.error import HTTPError
@ -21,9 +19,9 @@ urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)
def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}) -> tuple[int, str | None]:
def http_request(url, timeout, headers={}):
try:
request = urllib.request.Request(url, headers=dict(headers))
request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
@ -35,9 +33,9 @@ def http_request(url: str, timeout: int | float, headers: Mapping[str, str] = {}
raise e
def url_encode(params: dict[str, Any]) -> str:
def url_encode(params):
return parse.urlencode(params)
def makedirs_wrapper(path: str) -> None:
def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True)

View File

@ -1,6 +1,5 @@
import hashlib
import socket
from typing import Any
from .python_3x import url_encode
@ -11,7 +10,7 @@ NAMESPACE_NAME = "namespaceName"
# add timestamps uris and keys
def signature(timestamp: str, uri: str, secret: str) -> str:
def signature(timestamp, uri, secret):
import base64
import hmac
@ -20,16 +19,16 @@ def signature(timestamp: str, uri: str, secret: str) -> str:
return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params: dict[str, Any]) -> str:
def url_encode_wrapper(params):
return url_encode(params)
def no_key_cache_key(namespace: str, key: str) -> str:
def no_key_cache_key(namespace, key):
return f"{namespace}{len(namespace)}{key}"
# Returns whether the obtained value is obtained, and None if it does not
def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any | None:
def get_value_from_dict(namespace_cache, key):
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
@ -39,7 +38,7 @@ def get_value_from_dict(namespace_cache: dict[str, Any] | None, key: str) -> Any
return None
def init_ip() -> str:
def init_ip():
ip = ""
s = None
try:

View File

@ -11,5 +11,5 @@ class RemoteSettingsSource:
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value

View File

@ -11,16 +11,16 @@ logger = logging.getLogger(__name__)
from configs.remote_settings_sources.base import RemoteSettingsSource
from .utils import parse_config
from .utils import _parse_config
class NacosSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.configs = configs
self.remote_configs: dict[str, str] = {}
self.remote_configs: dict[str, Any] = {}
self.async_init()
def async_init(self) -> None:
def async_init(self):
data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
@ -29,19 +29,22 @@ class NacosSettingsSource(RemoteSettingsSource):
try:
content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
self.remote_configs = self._parse_config(content)
except Exception:
except Exception as e:
logger.exception("[get-access-token] exception occurred")
raise
def _parse_config(self, content: str) -> dict[str, str]:
def _parse_config(self, content: str) -> dict:
if not content:
return {}
try:
return parse_config(content)
return _parse_config(self, content)
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
if field_value is None:
return None, field_name, False

View File

@ -17,26 +17,20 @@ class NacosHttpClient:
self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
self.token: str | None = None
self.token = None
self.token_ttl = 18000
self.token_expire_time: float = 0
def http_request(
self, url: str, method: str = "GET", headers: dict[str, str] | None = None, params: dict[str, str] | None = None
) -> str:
if headers is None:
headers = {}
if params is None:
params = {}
def http_request(self, url, method="GET", headers=None, params=None):
try:
self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status()
return response.text
except requests.RequestException as e:
except requests.exceptions.RequestException as e:
return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
def _inject_auth_info(self, headers, params, module="config"):
headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
if module == "login":
@ -51,17 +45,16 @@ class NacosHttpClient:
headers["timeStamp"] = ts
if self.username and self.password:
self.get_access_token(force_refresh=False)
if self.token is not None:
params["accessToken"] = self.token
params["accessToken"] = self.token
def __do_sign(self, sign_str: str, sk: str) -> str:
def __do_sign(self, sign_str, sk):
return (
base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
.decode()
.strip()
)
def get_sign_str(self, group: str, tenant: str, ts: str) -> str:
def get_sign_str(self, group, tenant, ts):
sign_str = ""
if tenant:
sign_str = tenant + "+"
@ -70,7 +63,7 @@ class NacosHttpClient:
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
return sign_str
def get_access_token(self, force_refresh: bool = False) -> str | None:
def get_access_token(self, force_refresh=False):
current_time = time.time()
if self.token and not force_refresh and self.token_expire_time > current_time:
return self.token
@ -84,7 +77,6 @@ class NacosHttpClient:
self.token = response_data.get("accessToken")
self.token_ttl = response_data.get("tokenTtl", 18000)
self.token_expire_time = current_time + self.token_ttl - 10
return self.token
except Exception:
except Exception as e:
logger.exception("[get-access-token] exception occur")
raise

View File

@ -1,4 +1,4 @@
def parse_config(content: str) -> dict[str, str]:
def _parse_config(self, content: str) -> dict[str, str]:
config: dict[str, str] = {}
if not content:
return config

View File

@ -19,7 +19,6 @@ language_timezone_mapping = {
"fa-IR": "Asia/Tehran",
"sl-SI": "Europe/Ljubljana",
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
}
languages = list(language_timezone_mapping.keys())

View File

@ -3,7 +3,6 @@ from threading import Lock
from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)

View File

@ -43,7 +43,7 @@ api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm"
api.add_resource(AppImportCheckDependenciesApi, "/apps/imports/<string:app_id>/check-dependencies")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, spec, version
from . import admin, apikey, extension, feature, ping, setup, version
# Import app controllers
from .app import (
@ -70,7 +70,7 @@ from .app import (
)
# Import auth controllers
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
# Import billing controllers
from .billing import billing, compliance
@ -84,17 +84,9 @@ from .datasets import (
external,
hit_testing,
metadata,
upload_file,
website,
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_draft_variable,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (

View File

@ -1,6 +1,4 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, reqparse
@ -8,8 +6,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
P = ParamSpec("P")
R = TypeVar("R")
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
@ -18,9 +14,9 @@ from extensions.ext_database import db
from models.model import App, InstalledApp, RecommendedApp
def admin_required(view: Callable[P, R]):
def admin_required(view):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
@ -134,19 +130,15 @@ class InsertExploreAppApi(Resource):
app.is_public = False
with Session(db.engine) as session:
installed_apps = (
session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
installed_apps = session.execute(
select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
.scalars()
.all()
)
).all()
for installed_app in installed_apps:
session.delete(installed_app)
for installed_app in installed_apps:
db.session.delete(installed_app)
db.session.delete(recommended_app)
db.session.commit()

View File

@ -84,10 +84,10 @@ class BaseApiKeyListResource(Resource):
flask_restx.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
custom="max_keys_exceeded",
code="max_keys_exceeded",
)
key = ApiToken.generate_api_key(self.token_prefix or "", 24)
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
setattr(api_token, self.resource_id_field, resource_id)
api_token.tenant_id = current_user.current_tenant_id

View File

@ -237,14 +237,9 @@ class AppExportApi(Resource):
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args")
parser.add_argument("workflow_id", type=str, location="args")
args = parser.parse_args()
return {
"data": AppDslService.export_dsl(
app_model=app_model, include_secret=args["include_secret"], workflow_id=args.get("workflow_id")
)
}
return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])}
class AppNameApi(Resource):

View File

@ -31,8 +31,6 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
logger = logging.getLogger(__name__)
class ChatMessageAudioApi(Resource):
@setup_required
@ -51,7 +49,7 @@ class ChatMessageAudioApi(Resource):
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
@ -72,7 +70,7 @@ class ChatMessageAudioApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("Failed to handle post request to ChatMessageAudioApi")
logging.exception("Failed to handle post request to ChatMessageAudioApi")
raise InternalServerError()
@ -99,7 +97,7 @@ class ChatMessageTextApi(Resource):
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
@ -120,7 +118,7 @@ class ChatMessageTextApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("Failed to handle post request to ChatMessageTextApi")
logging.exception("Failed to handle post request to ChatMessageTextApi")
raise InternalServerError()
@ -162,7 +160,7 @@ class TextModesApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("Failed to handle get request to TextModesApi")
logging.exception("Failed to handle get request to TextModesApi")
raise InternalServerError()

View File

@ -34,8 +34,6 @@ from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
# define completion message api for user
class CompletionMessageApi(Resource):
@ -69,7 +67,7 @@ class CompletionMessageApi(Resource):
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -82,7 +80,7 @@ class CompletionMessageApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -136,7 +134,7 @@ class ChatMessageApi(Resource):
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -151,7 +149,7 @@ class ChatMessageApi(Resource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -117,7 +117,7 @@ class CompletionConversationDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def delete(self, app_model, conversation_id):
if not current_user.is_editor:
raise Forbidden()

View File

@ -16,10 +16,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs.login import login_required
from models import App
from services.workflow_service import WorkflowService
class RuleGenerateApi(Resource):
@ -138,6 +135,9 @@ class InstructionGenerateApi(Resource):
try:
# Generate from nothing for a workflow node
if (args["current"] == code_template or args["current"] == "") and args["node_id"] != "":
from models import App, db
from services.workflow_service import WorkflowService
app = db.session.query(App).where(App.id == args["flow_id"]).first()
if not app:
return {"error": f"app {args['flow_id']} not found"}, 400
@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
def post(self) -> dict:
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args()

View File

@ -3,7 +3,6 @@ import logging
from flask_login import current_user
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from sqlalchemy import exists, select
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
@ -34,8 +33,6 @@ from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
logger = logging.getLogger(__name__)
class ChatMessageListApi(Resource):
message_infinite_scroll_pagination_fields = {
@ -95,22 +92,21 @@ class ChatMessageListApi(Resource):
.all()
)
# Initialize has_more based on whether we have a full page
has_more = False
if len(history_messages) == args["limit"]:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
select(
exists().where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
rest_count = (
db.session.query(Message)
.where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
)
.count()
)
else:
# If we don't have a full page, there are no more messages
has_more = False
if rest_count > 0:
has_more = True
history_messages = list(reversed(history_messages))
@ -130,7 +126,7 @@ class MessageFeedbackApi(Resource):
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
@ -219,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource):
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
return {"data": questions}

View File

@ -24,7 +24,6 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.helper.trace_id_helper import get_external_trace_id
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
@ -73,7 +72,6 @@ class DraftWorkflowApi(Resource):
Get draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
raise Forbidden()
@ -96,7 +94,6 @@ class DraftWorkflowApi(Resource):
Sync draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
raise Forbidden()
@ -174,7 +171,6 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
Run draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
assert isinstance(current_user, Account)
if not current_user.is_editor:
raise Forbidden()
@ -209,7 +205,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -222,12 +218,13 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
"""
Run draft workflow iteration node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -245,7 +242,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -259,11 +256,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
Run draft workflow iteration node
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -281,7 +279,7 @@ class WorkflowDraftRunIterationNodeApi(Resource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -294,13 +292,13 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -318,7 +316,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -331,13 +329,13 @@ class WorkflowDraftRunLoopNodeApi(Resource):
"""
Run draft workflow loop node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, location="json")
args = parser.parse_args()
@ -355,7 +353,7 @@ class WorkflowDraftRunLoopNodeApi(Resource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -368,13 +366,13 @@ class DraftWorkflowRunApi(Resource):
"""
Run draft workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
@ -407,19 +405,11 @@ class WorkflowTaskStopApi(Resource):
"""
Stop workflow task
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id)
return {"result": "success"}
@ -434,13 +424,13 @@ class DraftWorkflowNodeRunApi(Resource):
"""
Run draft workflow node
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
@ -482,9 +472,6 @@ class PublishedWorkflowApi(Resource):
"""
Get published workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
@ -504,12 +491,13 @@ class PublishedWorkflowApi(Resource):
"""
Publish workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
@ -532,7 +520,7 @@ class PublishedWorkflowApi(Resource):
)
app_model.workflow_id = workflow.id
db.session.commit() # NOTE: this is necessary for update app_model.workflow_id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
@ -553,9 +541,6 @@ class DefaultBlockConfigsApi(Resource):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
@ -574,12 +559,13 @@ class DefaultBlockConfigApi(Resource):
"""
Get default block config
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("q", type=str, location="args")
args = parser.parse_args()
@ -609,12 +595,13 @@ class ConvertToWorkflowApi(Resource):
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
if not isinstance(current_user, Account):
raise Forbidden()
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
if request.data:
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
@ -658,9 +645,6 @@ class PublishedAllWorkflowApi(Resource):
"""
Get published workflows
"""
if not isinstance(current_user, Account):
raise Forbidden()
if not current_user.is_editor:
raise Forbidden()
@ -709,12 +693,13 @@ class WorkflowByIdApi(Resource):
"""
Update workflow attributes
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("marked_name", type=str, required=False, location="json")
parser.add_argument("marked_comment", type=str, required=False, location="json")
@ -765,12 +750,13 @@ class WorkflowByIdApi(Resource):
"""
Delete workflow
"""
if not isinstance(current_user, Account):
raise Forbidden()
# Check permission
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
workflow_service = WorkflowService()
# Create a session and manage the transaction

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
@ -27,9 +27,7 @@ class WorkflowAppLogApi(Resource):
"""
parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args")
parser.add_argument(
"status", type=str, choices=["succeeded", "failed", "stopped", "partial-succeeded"], location="args"
)
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument(
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
)

View File

@ -1,5 +1,5 @@
import logging
from typing import NoReturn
from typing import Any, NoReturn
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@ -13,17 +13,14 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode
from models.account import Account
from models import App, AppMode, db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@ -31,7 +28,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment):
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
@ -42,7 +39,7 @@ def _convert_values_to_json_serializable_object(value: Segment):
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable):
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
@ -76,22 +73,6 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
return value_type.exposed_type().value
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
variable_file = variable.variable_file
assert variable_file is not None
return {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
@ -101,13 +82,11 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
"is_truncated": fields.Boolean(attribute=lambda model: model.file_id is not None),
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
full_content=fields.Raw(attribute=_serialize_full_content),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
@ -156,7 +135,6 @@ def _api_prerequisite(f):
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
assert isinstance(current_user, Account)
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -6,11 +6,9 @@ from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models import App, AppMode
from models.account import Account
def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account)
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@ -13,8 +13,6 @@ from libs.oauth_data_source import NotionOAuth
from ..wraps import account_initialization_required, setup_required
logger = logging.getLogger(__name__)
def get_oauth_providers():
with current_app.app_context():
@ -81,8 +79,8 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.HTTPError as e:
logger.exception(
except requests.exceptions.HTTPError as e:
logging.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
return {"error": "OAuth data source process failed"}, 400
@ -104,8 +102,8 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e:
logger.exception(
except requests.exceptions.HTTPError as e:
logging.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
)
return {"error": "OAuth data source process failed"}, 400

View File

@ -55,12 +55,6 @@ class EmailOrPasswordMismatchError(BaseHTTPException):
code = 400
class AuthenticationFailedError(BaseHTTPException):
error_code = "authentication_failed"
description = "Invalid email or password."
code = 401
class EmailPasswordLoginLimitError(BaseHTTPException):
error_code = "email_code_login_limit"
description = "Too many incorrect password attempts. Please try again later."

View File

@ -9,8 +9,8 @@ from configs import dify_config
from constants.languages import languages
from controllers.console import api
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
EmailOrPasswordMismatchError,
EmailPasswordLoginLimitError,
InvalidEmailError,
InvalidTokenError,
@ -79,7 +79,7 @@ class LoginApi(Resource):
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise AuthenticationFailedError()
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
@ -130,9 +130,8 @@ class ResetPasswordSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError:
except AccountRegisterError as are:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
@ -162,7 +161,7 @@ class EmailCodeLoginSendEmailApi(Resource):
language = "en-US"
try:
account = AccountService.get_user_through_email(args["email"])
except AccountRegisterError:
except AccountRegisterError as are:
raise AccountInFreezeError()
if account is None:
@ -200,7 +199,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args["token"])
try:
account = AccountService.get_user_through_email(user_email)
except AccountRegisterError:
except AccountRegisterError as are:
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@ -222,8 +221,8 @@ class EmailCodeLoginApi(Resource):
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
return NotAllowedCreateWorkspace()
except AccountRegisterError as are:
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()

View File

@ -24,8 +24,6 @@ from services.feature_service import FeatureService
from .. import api
logger = logging.getLogger(__name__)
def get_oauth_providers():
with current_app.app_context():
@ -80,9 +78,9 @@ class OAuthCallback(Resource):
try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e:
except requests.exceptions.RequestException as e:
error_text = e.response.text if e.response else str(e)
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token):

View File

@ -1,202 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast
import flask_login
from flask import jsonify, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import Account
from models.model import OAuthProviderApp
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
from .. import api
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
parser = reqparse.RequestParser()
parser.add_argument("client_id", type=str, required=True, location="json")
parsed_args = parser.parse_args()
client_id = parsed_args.get("client_id")
if not client_id:
raise BadRequest("client_id is required")
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
if not oauth_provider_app:
raise NotFound("client_id is invalid")
return view(self, oauth_provider_app, *args, **kwargs)
return decorated
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
response = jsonify({"error": "Authorization header is required"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
parts = authorization_header.strip().split(None, 1)
if len(parts) != 2:
response = jsonify({"error": "Invalid Authorization header format"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
token_type = parts[0].strip()
if token_type.lower() != "bearer":
response = jsonify({"error": "token_type is invalid"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
access_token = parts[1].strip()
if not access_token:
response = jsonify({"error": "access_token is required"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
if not account:
response = jsonify({"error": "access_token or client_id is invalid"})
response.status_code = 401
response.headers["WWW-Authenticate"] = "Bearer"
return response
return view(self, oauth_provider_app, account, *args, **kwargs)
return decorated
class OAuthServerAppApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("redirect_uri", type=str, required=True, location="json")
parsed_args = parser.parse_args()
redirect_uri = parsed_args.get("redirect_uri")
# check if redirect_uri is valid
if redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
return jsonable_encoder(
{
"app_icon": oauth_provider_app.app_icon,
"app_label": oauth_provider_app.app_label,
"scope": oauth_provider_app.scope,
}
)
class OAuthServerUserAuthorizeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
account = cast(Account, flask_login.current_user)
user_account_id = account.id
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
return jsonable_encoder(
{
"code": code,
}
)
class OAuthServerUserTokenApi(Resource):
@setup_required
@oauth_server_client_id_required
def post(self, oauth_provider_app: OAuthProviderApp):
parser = reqparse.RequestParser()
parser.add_argument("grant_type", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=False, location="json")
parser.add_argument("client_secret", type=str, required=False, location="json")
parser.add_argument("redirect_uri", type=str, required=False, location="json")
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
raise BadRequest("code is required")
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not parsed_args["refresh_token"]:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
class OAuthServerUserAccountApi(Resource):
@setup_required
@oauth_server_client_id_required
@oauth_server_access_token_required
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
return jsonable_encoder(
{
"name": account.name,
"email": account.email,
"avatar": account.avatar,
"interface_language": account.interface_language,
"timezone": account.timezone,
}
)
api.add_resource(OAuthServerAppApi, "/oauth/provider")
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")

View File

@ -1,9 +1,9 @@
from flask_login import current_user
from flask_restx import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_user, login_required
from models.model import Account
from libs.login import login_required
from services.billing_service import BillingService
@ -17,10 +17,9 @@ class Subscription(Resource):
parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"])
parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"])
args = parser.parse_args()
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_subscription(
args["plan"], args["interval"], current_user.email, current_user.current_tenant_id
)
@ -32,9 +31,7 @@ class Invoices(Resource):
@account_initialization_required
@only_edition_cloud
def get(self):
assert isinstance(current_user, Account)
BillingService.is_tenant_owner_or_admin(current_user)
assert current_user.current_tenant_id is not None
return BillingService.get_invoices(current_user.email, current_user.current_tenant_id)

View File

@ -1,6 +1,4 @@
import json
from collections.abc import Generator
from typing import cast
from flask import request
from flask_login import current_user
@ -11,10 +9,7 @@ from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.datasource.entities.datasource_entities import DatasourceProviderType, OnlineDocumentPagesMessage
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
@ -23,7 +18,6 @@ from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
@ -118,18 +112,6 @@ class DataSourceNotionListApi(Resource):
@marshal_with(integrate_notion_info_list_fields)
def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
@ -153,49 +135,31 @@ class DataSourceNotionListApi(Resource):
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_user.current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
if credential:
datasource_runtime.runtime.credentials = credential
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
data_source_bindings = session.scalars(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
)
try:
pages = []
workspace_info = {}
for message in online_document_result:
result = message.result
for info in result:
workspace_info = {
"workspace_id": info.workspace_id,
"workspace_name": info.workspace_name,
"workspace_icon": info.workspace_icon,
}
for page in info.pages:
page_info = {
"page_id": page.page_id,
"page_name": page.page_name,
"type": page.type,
"parent_id": page.parent_id,
"is_bound": page.page_id in exist_page_ids,
"page_icon": page.page_icon,
}
pages.append(page_info)
except Exception as e:
raise e
return {"notion_info": {**workspace_info, "pages": pages}}, 200
).all()
if not data_source_bindings:
return {"notion_info": []}, 200
pre_import_info_list = []
for data_source_binding in data_source_bindings:
source_info = data_source_binding.source_info
pages = source_info["pages"]
# Filter out already bound pages
for page in pages:
if page["page_id"] in exist_page_ids:
page["is_bound"] = True
else:
page["is_bound"] = False
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource):
@ -203,25 +167,27 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
def get(self, workspace_id, page_id, page_type):
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
workspace_id = str(workspace_id)
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding:
raise NotFound("Data source binding not found.")
extractor = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
notion_access_token=data_source_binding.access_token,
tenant_id=current_user.current_tenant_id,
)
@ -246,12 +212,10 @@ class DataSourceNotionApi(Resource):
extract_settings = []
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],

View File

@ -19,9 +19,9 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
@ -31,7 +31,6 @@ from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@ -280,15 +279,6 @@ class DatasetApi(Resource):
location="json",
help="Invalid external knowledge api id.",
)
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
args = parser.parse_args()
data = request.get_json()
@ -432,21 +422,17 @@ class DatasetIndexingEstimateApi(Resource):
if file_details:
for file_detail in file_details:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value,
upload_file=file_detail,
document_model=args["doc_form"],
datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"]
)
extract_settings.append(extract_setting)
elif args["info_list"]["data_source_type"] == "notion_import":
notion_info_list = args["info_list"]["notion_info_list"]
for notion_info in notion_info_list:
workspace_id = notion_info["workspace_id"]
credential_id = notion_info.get("credential_id")
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type="notion_import",
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
@ -459,7 +445,7 @@ class DatasetIndexingEstimateApi(Resource):
website_info_list = args["info_list"]["website_info_list"]
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type="website_crawl",
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
@ -567,7 +553,7 @@ class DatasetIndexingStatusApi(Resource):
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data, 200
return data
class DatasetApiKeyApi(Resource):

View File

@ -1,4 +1,3 @@
import json
import logging
from argparse import ArgumentTypeError
from typing import Literal, cast
@ -41,7 +40,6 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.document_fields import (
@ -53,12 +51,9 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
logger = logging.getLogger(__name__)
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
@ -357,6 +352,9 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@ -428,7 +426,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file, document_model=document.doc_form
datasource_type="upload_file", upload_file=file, document_model=document.doc_form
)
indexing_runner = IndexingRunner()
@ -470,15 +468,27 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
info_list = []
extract_settings = []
for document in documents:
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
# format document files info
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
info_list.append(file_id)
# format document notion info
elif (
data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info
):
pages = []
page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]}
pages.append(page)
notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages}
info_list.append(notion_info)
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
@ -490,17 +500,14 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise NotFound("File not found.")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE.value, upload_file=file_detail, document_model=document.doc_form
datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
datasource_type="notion_import",
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
@ -510,10 +517,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
datasource_type="website_crawl",
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
@ -656,7 +661,7 @@ class DocumentApi(DocumentResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -961,7 +966,7 @@ class DocumentRetryApi(DocumentResource):
raise DocumentAlreadyFinishedError()
retry_documents.append(document)
except Exception:
logger.exception("Failed to retry document, document id: %s", document_id)
logging.exception("Failed to retry document, document id: %s", document_id)
continue
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
@ -1019,41 +1024,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200
class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
)
if not log:
return {
"datasource_info": None,
"datasource_type": None,
"input_data": None,
"datasource_node_id": None,
}, 200
return {
"datasource_info": json.loads(log.datasource_info),
"datasource_type": log.datasource_type,
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
@ -1075,6 +1045,3 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@ -71,9 +71,3 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500
class PipelineNotFoundError(BaseHTTPException):
error_code = "pipeline_not_found"
description = "Pipeline not found."
code = 404

View File

@ -23,8 +23,6 @@ from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
logger = logging.getLogger(__name__)
class DatasetsHitTestingBase:
@staticmethod
@ -83,5 +81,5 @@ class DatasetsHitTestingBase:
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logger.exception("Hit testing failed.")
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))

View File

@ -1,362 +0,0 @@
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from libs.helper import StrLen
from libs.login import login_required
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
oauth_config = DatasourceProviderService().get_oauth_client(
tenant_id=tenant_id,
datasource_provider_id=datasource_provider_id,
)
if not oauth_config:
raise ValueError(f"No OAuth Client Config for {provider_id}")
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id,
tenant_id=tenant_id,
plugin_id=plugin_id,
provider=provider_name,
credential_id=credential_id,
)
oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_config,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class DatasourceOAuthCallback(Resource):
@setup_required
def get(self, provider_id: str):
context_id = request.cookies.get("context_id") or request.args.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
datasource_provider_service = DatasourceProviderService()
oauth_client_params = datasource_provider_service.get_oauth_client(
tenant_id=tenant_id,
datasource_provider_id=datasource_provider_id,
)
if not oauth_client_params:
raise NotFound()
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
oauth_handler = OAuthHandler()
oauth_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=datasource_provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
)
credential_id = context.get("credential_id")
if credential_id:
datasource_provider_service.reauthorize_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=datasource_provider_id,
avatar_url=oauth_response.metadata.get("avatar_url") or None,
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
credential_id=context.get("credential_id"),
)
else:
datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=datasource_provider_id,
avatar_url=oauth_response.metadata.get("avatar_url") or None,
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None
)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
provider_id=datasource_provider_id,
credentials=args["credentials"],
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=args["credential_id"],
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
class DatasourceAuthUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
credentials=args.get("credentials", {}),
name=args.get("name", None),
)
return {"result": "success"}, 201
class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200
class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(
tenant_id=current_user.current_tenant_id
)
return {"result": jsonable_encoder(datasources)}, 200
class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.setup_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
client_params=args.get("client_params", {}),
enabled=args.get("enable_oauth_custom_client", False),
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["id"],
)
return {"result": "success"}, 200
class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],
)
return {"result": "success"}, 200
api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
)
api.add_resource(
DatasourceOAuthCallback,
"/oauth/plugin/<path:provider_id>/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource/<path:provider_id>",
)
api.add_resource(
DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
)
api.add_resource(
DatasourceAuthListApi,
"/auth/plugin/datasource/list",
)
api.add_resource(
DatasourceHardCodeAuthListApi,
"/auth/plugin/datasource/default-list",
)
api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)
api.add_resource(
DatasourceAuthDefaultApi,
"/auth/plugin/datasource/<path:provider_id>/default",
)
api.add_resource(
DatasourceUpdateProviderNameApi,
"/auth/plugin/datasource/<path:provider_id>/update-name",
)

View File

@ -1,57 +0,0 @@
from flask_restx import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DataSourceContentPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run datasource content preview
"""
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("credential_id", type=str, required=False, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=datasource_type,
is_published=True,
credential_id=args.get("credential_id"),
)
return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@ -1,164 +0,0 @@
import logging
from flask import request
from flask_restx import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
knowledge_pipeline_publish_enabled,
setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
return {"data": template.yaml_content}, 200
class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@knowledge_pipeline_publish_enabled
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@ -1,114 +0,0 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restx import Resource, marshal, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(
name="",
description="",
icon_info=IconInfo(
icon="📙",
icon_background="#FFF4ED",
icon_type="emoji",
),
permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None,
yaml_content=args["yaml_content"],
)
try:
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(
name="",
description="",
icon_info=IconInfo(
icon="📙",
icon_background="#FFF4ED",
icon_type="emoji",
),
permission=DatasetPermissionEnum.ONLY_ME,
partial_member_list=None,
),
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@ -1,389 +0,0 @@
import logging
from typing import Any, NoReturn
from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.workflow_draft_variable import (
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models.account import Account
from models.dataset import Pipeline
from models.workflow import WorkflowDraftVariable
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
- Dify has been property setup.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
class RagPipelineVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, pipeline: Pipeline):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow_exist = rag_pipeline_service.is_workflow_exist(pipeline=pipeline)
if not workflow_exist:
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=pipeline.id,
page=args.page,
limit=args.limit,
)
return workflow_vars
@_api_prerequisite
def delete(self, pipeline: Pipeline):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(pipeline.id)
db.session.commit()
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
# with specific `node_id` in database, we still want to make the API separated. By disallowing
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk that user of the API depending on the implementation detail of the API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
class RagPipelineNodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id)
return node_vars
@_api_prerequisite
def delete(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(pipeline.id, node_id)
db.session.commit()
return Response("", 204)
class RagPipelineVariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
# Local File:
#
# {
# "type": "image",
# "transfer_method": "local_file",
# "url": "",
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
# }
#
# Remote File:
#
#
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
rag_pipeline_service = RagPipelineService()
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(pipeline.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(pipeline.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id)
return draft_vars
class RagPipelineSystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, pipeline: Pipeline):
return _get_variable_list(pipeline, SYSTEM_VARIABLE_NODE_ID)
class RagPipelineEnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, pipeline: Pipeline):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}
api.add_resource(
RagPipelineVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables",
)
api.add_resource(
RagPipelineNodeVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/variables",
)
api.add_resource(
RagPipelineVariableApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>"
)
api.add_resource(
RagPipelineVariableResetApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset"
)
api.add_resource(
RagPipelineSystemVariableCollectionApi, "/rag/pipelines/<uuid:pipeline_id>/workflows/draft/system-variables"
)
api.add_resource(
RagPipelineEnvironmentVariableCollectionApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/draft/environment-variables",
)

View File

@ -1,147 +0,0 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
dataset_name=args.get("name"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@ -0,0 +1,62 @@
from flask_login import current_user
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import UploadFile
from services.dataset_service import DocumentService
class UploadFileApi(Resource):
@setup_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""Get upload file."""
# check dataset
dataset_id = str(dataset_id)
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == dataset_id)
.first()
)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check upload file
if document.data_source_type != "upload_file":
raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
raise ValueError("Upload file id not found in document data source info.")
url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": url,
"download_url": f"{url}&as_attachment=true",
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at.timestamp(),
}, 200
api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")

View File

@ -1,47 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.account import Account
from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
if not isinstance(current_user, Account):
raise ValueError("current_user is not an account")
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
if not pipeline:
raise PipelineNotFoundError()
kwargs["pipeline"] = pipeline
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -26,8 +26,6 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
logger = logging.getLogger(__name__)
class ChatAudioApi(InstalledAppResource):
def post(self, installed_app):
@ -40,7 +38,7 @@ class ChatAudioApi(InstalledAppResource):
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
@ -61,7 +59,7 @@ class ChatAudioApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -85,7 +83,7 @@ class ChatTextApi(InstalledAppResource):
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
@ -106,5 +104,5 @@ class ChatTextApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -32,8 +32,6 @@ from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
# define completion api for user
class CompletionApi(InstalledAppResource):
@ -67,7 +65,7 @@ class CompletionApi(InstalledAppResource):
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -80,7 +78,7 @@ class CompletionApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -127,7 +125,7 @@ class ChatApi(InstalledAppResource):
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
logging.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -142,7 +140,7 @@ class ChatApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()

View File

@ -61,6 +61,7 @@ class ConversationApi(InstalledAppResource):
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}, 204

View File

@ -35,8 +35,6 @@ from services.errors.message import (
)
from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@ -128,7 +126,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -160,7 +158,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
return {"data": questions}

View File

@ -43,8 +43,6 @@ class ExploreAppMetaApi(InstalledAppResource):
def get(self, installed_app: InstalledApp):
"""Get app meta"""
app_model = installed_app.app
if not app_model:
raise ValueError("App not found")
return AppService().get_app_meta(app_model)

View File

@ -20,7 +20,6 @@ from core.errors.error import (
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from models.model import AppMode, InstalledApp
@ -36,8 +35,6 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
Run workflow
"""
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
@ -46,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
@ -66,7 +63,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource):
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
logging.exception("internal server error.")
raise InternalServerError()
@ -76,18 +73,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
Stop workflow task
"""
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}

View File

@ -1,6 +1,4 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate, Optional, ParamSpec, TypeVar
from flask_login import current_user
from flask_restx import Resource
@ -15,15 +13,19 @@ from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
def installed_app_required(view=None):
def decorator(view):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
if not kwargs.get("installed_app_id"):
raise ValueError("missing installed_app_id in path parameters")
installed_app_id = kwargs.get("installed_app_id")
installed_app_id = str(installed_app_id)
del kwargs["installed_app_id"]
installed_app = (
db.session.query(InstalledApp)
.where(
@ -50,10 +52,10 @@ def installed_app_required(view: Optional[Callable[Concatenate[InstalledApp, P],
return decorator
def user_allowed_to_access_app(view: Optional[Callable[Concatenate[InstalledApp, P], R]] = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
def user_allowed_to_access_app(view=None):
def decorator(view):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
def decorated(installed_app: InstalledApp, *args, **kwargs):
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id

View File

@ -20,7 +20,6 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService
@ -69,7 +68,7 @@ class FileApi(Resource):
source = None
try:
upload_file = FileService(db.engine).upload_file(
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
@ -90,7 +89,7 @@ class FilePreviewApi(Resource):
@account_initialization_required
def get(self, file_id):
file_id = str(file_id)
text = FileService(db.engine).get_file_preview(file_id)
text = FileService.get_file_preview(file_id)
return {"content": text}

View File

@ -14,7 +14,6 @@ from controllers.common.errors import (
)
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService
@ -62,7 +61,7 @@ class RemoteFileUploadApi(Resource):
try:
user = cast(Account, current_user)
upload_file = FileService(db.engine).upload_file(
upload_file = FileService.upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,

View File

@ -1,35 +0,0 @@
import logging
from flask_restx import Resource
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.schemas.schema_manager import SchemaManager
from libs.login import login_required
logger = logging.getLogger(__name__)
class SpecSchemaDefinitionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""
Get system JSON Schema definitions specification
Used for frontend component type mapping
"""
try:
schema_manager = SchemaManager()
schema_definitions = schema_manager.get_all_schema_definitions()
return schema_definitions, 200
except Exception:
logger.exception("Failed to get schema definitions from local registry")
# Return empty array as fallback
return [], 200
api.add_resource(SpecSchemaDefinitionsApi, "/spec/schema-definitions")

View File

@ -9,8 +9,6 @@ from configs import dify_config
from . import api
logger = logging.getLogger(__name__)
class VersionApi(Resource):
def get(self):
@ -36,7 +34,7 @@ class VersionApi(Resource):
try:
response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10))
except Exception as error:
logger.warning("Check update version error: %s.", str(error))
logging.warning("Check update version error: %s.", str(error))
result["version"] = args.get("current_version")
return result
@ -57,7 +55,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool:
# Compare versions
return latest > current
except version.InvalidVersion:
logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
return False

View File

@ -1,6 +1,4 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask_login import current_user
from sqlalchemy.orm import Session
@ -9,17 +7,14 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor(view: Callable[P, R]):
def interceptor(view):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
def decorated(*args, **kwargs):
user = current_user
tenant_id = user.current_tenant_id

View File

@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import Account, TenantAccountRole
from models.account import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
@ -15,12 +15,10 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
assert isinstance(current_user, Account)
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")
@ -66,12 +64,10 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
assert isinstance(current_user, Account)
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
assert tenant_id is not None
parser = reqparse.RequestParser()
parser.add_argument("model", type=str, required=True, nullable=False, location="json")

View File

@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource):
@cloud_edition_billing_resource_check("members")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("emails", type=list, required=True, location="json")
parser.add_argument("emails", type=str, required=True, location="json", action="append")
parser.add_argument("role", type=str, required=True, default="admin", location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()

View File

@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import StrLen, uuid_value
from libs.login import login_required
from services.billing_service import BillingService
from services.model_provider_service import ModelProviderService
@ -46,109 +45,12 @@ class ModelProviderCredentialApi(Resource):
@account_initialization_required
def get(self, provider: str):
tenant_id = current_user.current_tenant_id
# if credential_id is not provided, return current used credential
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
args = parser.parse_args()
model_provider_service = ModelProviderService()
credentials = model_provider_service.get_provider_credential(
tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id")
)
credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider)
return {"credentials": credentials}
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.create_provider_credential(
tenant_id=current_user.current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def put(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.update_provider_credential(
tenant_id=current_user.current_tenant_id,
provider=provider,
credentials=args["credentials"],
credential_id=args["credential_id"],
credential_name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credential(
tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"]
)
return {"result": "success"}, 204
class ModelProviderCredentialSwitchApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
service = ModelProviderService()
service.switch_active_provider_credential(
tenant_id=current_user.current_tenant_id,
provider=provider,
credential_id=args["credential_id"],
)
return {"result": "success"}
class ModelProviderValidateApi(Resource):
@setup_required
@ -167,7 +69,7 @@ class ModelProviderValidateApi(Resource):
error = ""
try:
model_provider_service.validate_provider_credentials(
model_provider_service.provider_credentials_validate(
tenant_id=tenant_id, provider=provider, credentials=args["credentials"]
)
except CredentialsValidateFailedError as ex:
@ -182,6 +84,42 @@ class ModelProviderValidateApi(Resource):
return response
class ModelProviderApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
model_provider_service = ModelProviderService()
try:
model_provider_service.save_provider_credentials(
tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"]
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def delete(self, provider: str):
if not current_user.is_admin_or_owner:
raise Forbidden()
model_provider_service = ModelProviderService()
model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider)
return {"result": "success"}, 204
class ModelProviderIconApi(Resource):
"""
Get model provider icon
@ -249,10 +187,8 @@ class ModelProviderPaymentCheckoutUrlApi(Resource):
api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers")
api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials")
api.add_resource(
ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch"
)
api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate")
api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>")
api.add_resource(
PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type"

Some files were not shown because too many files have changed in this diff Show More