Compare commits

..

27 Commits

Author SHA1 Message Date
eff8d08151 [autofix.ci] apply automated fixes 2026-04-18 14:44:39 +00:00
c98d1b7834 Merge branch 'main' into refactor/quota-service 2026-04-18 22:42:17 +08:00
00b6294798 chore: fix tests 2026-04-18 22:41:51 +08:00
fea4ba24e6 refactor: enhance trigger db session management 2026-04-18 11:26:35 +08:00
fd0c87225e Merge branch 'main' into build/billing-v3 2026-04-17 14:15:48 +08:00
48c6b520ad Merge branch 'main' into build/billing-v3 2026-04-13 22:59:28 +08:00
ae01a5d137 fix: unit test mock 2026-04-08 14:42:52 +08:00
ad6670ebcc fix: correct quota info response 2026-04-08 14:23:57 +08:00
8ca0917044 Merge branch 'main' into feat/new-biliing-quota 2026-04-08 13:39:24 +08:00
b3870524d4 fix usage get 2026-04-02 09:52:52 +08:00
c543188434 fix linter 2026-03-31 15:22:51 +08:00
f319a9e42f fix test case 2026-03-31 15:22:43 +08:00
58241a89a5 fix linter 2026-03-31 14:59:54 +08:00
422bf3506e rebuild quota service 2026-03-31 14:59:45 +08:00
6e745f9e9b fix linter 2026-03-31 09:49:24 +08:00
4e50d55339 fix comment 2026-03-31 09:49:09 +08:00
b95cdabe26 [autofix.ci] apply automated fixes 2026-03-30 08:45:37 +00:00
daa47c25bb Merge branch 'feat/new-biliing-quota' of github.com:langgenius/dify into feat/new-biliing-quota 2026-03-30 16:43:13 +08:00
f1bcd6d715 add test case for quota and billing service 2026-03-30 16:41:56 +08:00
8643ff43f5 Merge branch 'main' into feat/new-biliing-quota 2026-03-30 15:57:49 +08:00
c5f30a47f0 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-30 15:26:38 +08:00
37d438fa19 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-27 16:26:09 +08:00
9503803997 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-23 09:27:39 +08:00
d6476f5434 Merge remote-tracking branch 'origin/main' into feat/new-biliing-quota 2026-03-20 15:17:27 +08:00
80b4633e8f fix style check and test 2026-03-20 14:58:31 +08:00
3888969af3 [autofix.ci] apply automated fixes 2026-03-20 05:45:30 +00:00
658ac15589 use new quota system 2026-03-20 13:29:22 +08:00
1256 changed files with 26226 additions and 40518 deletions

View File

@ -367,7 +367,7 @@ For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check
│ 3. Run: pnpm type-check:tsgo
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │

View File

@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:
-**Import real project components** directly (including base components and siblings)
-**Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-**DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
-**DO NOT mock** base components (`@/app/components/base/*`)
-**DO NOT mock** sibling/child components in the same directory
> See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
@ -325,12 +325,12 @@ For more detailed information, refer to:
### Reference Examples in Codebase
- `web/utils/classnames.spec.ts` - Utility function tests
- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
- `web/app/components/base/button/index.spec.tsx` - Component tests
- `web/__mocks__/provider-context.ts` - Mock factory example
### Project Configuration
- `web/vite.config.ts` - Vite/Vitest configuration
- `web/vitest.config.ts` - Vitest configuration
- `web/vitest.setup.ts` - Test environment setup
- `web/scripts/analyze-component.js` - Component analysis tool
- Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.

View File

@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Integration vs Mocking
- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
- [ ] Import real project components instead of mocking
- [ ] Only mock: API calls, complex context providers, third-party libs with side effects
- [ ] Prefer integration testing when using single spec file
@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend componen
### Mocks
- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
- [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
- [ ] Shared mock state reset in `beforeEach`
- [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
@ -127,7 +127,7 @@ For the current file being tested:
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check`
- [ ] Run `pnpm type-check:tsgo`
## Common Issues to Watch

View File

@ -2,27 +2,29 @@
## ⚠️ Important: What NOT to Mock
### DO NOT Mock Base Components or dify-ui Primitives
### DO NOT Mock Base Components
**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
**Never mock components from `@/app/components/base/`** such as:
- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
- `Loading`, `Spinner`
- `Button`, `Input`, `Select`
- `Tooltip`, `Modal`, `Dropdown`
- `Icon`, `Badge`, `Tag`
**Why?**
- These components have their own dedicated tests
- Base components will have their own dedicated tests
- Mocking them creates false positives (tests pass but real integration fails)
- Using real components tests actual integration behavior
```typescript
// ❌ WRONG: Don't mock base components or dify-ui primitives
// ❌ WRONG: Don't mock base components
vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)
// ✅ CORRECT: Import and use the real components
// ✅ CORRECT: Import and use real base components
import Loading from '@/app/components/base/loading'
import { Button } from '@langgenius/dify-ui/button'
import Button from '@/app/components/base/button'
// They will render normally in tests
```
@ -317,7 +319,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ✅ DO
1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
1. **Use real base components** - Import from `@/app/components/base/` directly
1. **Use real project components** - Prefer importing over mocking
1. **Use real Zustand stores** - Set test state via `store.setState()`
1. **Reset mocks in `beforeEach`**, not `afterEach`
@ -328,7 +330,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
### ❌ DON'T
1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
1. **Don't mock Zustand store modules** - Use real stores with `setState()`
1. Don't mock components you can import directly
1. Don't create overly simplified mocks that miss conditional logic
@ -340,7 +342,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
```
Need to use a component in test?
├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
├─ Is it from @/app/components/base/*?
│ └─ YES → Import real component, DO NOT mock
├─ Is it a project component?

3
.github/CODEOWNERS vendored
View File

@ -6,9 +6,6 @@
* @crazywoola @laipz8200 @Yeuoly
# ESLint suppression file is maintained by autofix.ci pruning.
/eslint-suppressions.json
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola

View File

@ -4,7 +4,7 @@ runs:
using: composite
steps:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
node-version-file: .nvmrc
cache: true

1
.github/labeler.yml vendored
View File

@ -6,4 +6,5 @@ web:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'

19
.github/workflows/anti-slop.yml vendored Normal file
View File

@ -0,0 +1,19 @@
name: Anti-Slop PR Check
on:
pull_request_target:
types: [opened, edited, synchronize]
permissions:
pull-requests: write
contents: read
jobs:
anti-slop:
runs-on: ubuntu-latest
steps:
- uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
close-pr: false
failure-add-pr-labels: "needs-revision"

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
api-unit:
name: API Unit Tests
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
env:
COVERAGE_FILE: coverage-unit
defaults:
@ -35,7 +35,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -62,7 +62,7 @@ jobs:
api-integration:
name: API Integration Tests
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
env:
COVERAGE_FILE: coverage-integration
STORAGE_TYPE: opendal
@ -84,7 +84,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -105,7 +105,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -137,7 +137,7 @@ jobs:
api-coverage:
name: API Coverage
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
needs:
- api-unit
- api-integration
@ -156,7 +156,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -13,7 +13,7 @@ permissions:
jobs:
autofix:
if: github.repository == 'langgenius/dify'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'
@ -25,7 +25,7 @@ jobs:
- name: Check Docker Compose inputs
if: github.event_name != 'merge_group'
id: docker-compose-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
docker/generate_docker_compose
@ -35,7 +35,7 @@ jobs:
- name: Check web inputs
if: github.event_name != 'merge_group'
id: web-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
@ -43,11 +43,12 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
api/**
@ -57,7 +58,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@ -122,4 +123,4 @@ jobs:
vp exec eslint --concurrency=2 --prune-suppressions --quiet || true
- if: github.event_name != 'merge_group'
uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3

View File

@ -26,9 +26,6 @@ jobs:
build:
runs-on: ${{ matrix.runs_on }}
if: github.repository == 'langgenius/dify'
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
@ -38,28 +35,28 @@ jobs:
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-latest
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-24.04-arm
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-latest
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-24.04-arm
steps:
- name: Prepare
@ -73,8 +70,8 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Extract metadata for Docker
id: meta
@ -84,15 +81,16 @@ jobs:
- name: Build Docker image
id: build
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
- name: Export digest
env:
@ -110,33 +108,9 @@ jobs:
if-no-files-found: error
retention-days: 1
fork-build-validate:
if: github.repository != 'langgenius/dify'
runs-on: ubuntu-24.04
strategy:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Validate Docker image
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
with:
push: false
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
platforms: linux/amd64
create-manifest:
needs: build
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
if: github.repository == 'langgenius/dify'
strategy:
matrix:

View File

@ -9,7 +9,7 @@ concurrency:
jobs:
db-migration-test-postgres:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Checkout code
@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"
@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -59,7 +59,7 @@ jobs:
run: uv run --directory api flask upgrade-db
db-migration-test-mysql:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Checkout code
@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"
@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@ -110,28 +110,6 @@ jobs:
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
# to return (container processes started); it does not wait on healthcheck
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
# wait the migration runs while InnoDB is still initialising and gets
# killed with "Lost connection during query". Poll a real SELECT until it
# succeeds.
- name: Wait for MySQL to accept queries
run: |
set +e
for i in $(seq 1 60); do
if docker run --rm --network host mysql:8.0 \
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
-e 'SELECT 1' >/dev/null 2>&1; then
echo "MySQL ready after ${i}s"
exit 0
fi
sleep 1
done
echo "MySQL not ready after 60s; dumping container logs:"
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
exit 1
- name: Run DB Migration
env:
DEBUG: true

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/agent-dev'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/dev'

View File

@ -13,7 +13,7 @@ on:
jobs:
deploy:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/enterprise'

View File

@ -10,7 +10,7 @@ on:
jobs:
deploy:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'

View File

@ -14,59 +14,28 @@ concurrency:
jobs:
build-docker:
if: github.event.pull_request.head.repo.full_name == github.repository
runs-on: ${{ matrix.runs_on }}
permissions:
contents: read
id-token: write
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-latest
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
runs_on: ubuntu-latest
context: "{{defaultContext}}"
file: "web/Dockerfile"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
- name: Set up Depot CLI
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
- name: Build Docker Image
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
with:
project: ${{ vars.DEPOT_PROJECT_ID }}
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: ${{ matrix.platform }}
build-docker-fork:
if: github.event.pull_request.head.repo.full_name != github.repository
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}"
file: "web/Dockerfile"
steps:
@ -79,4 +48,6 @@ jobs:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
platforms: linux/amd64
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@ -7,7 +7,7 @@ jobs:
permissions:
contents: read
pull-requests: write
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:

View File

@ -23,7 +23,7 @@ concurrency:
jobs:
pre_job:
name: Skip Duplicate Checks
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
outputs:
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
steps:
@ -39,7 +39,7 @@ jobs:
name: Check Changed Files
needs: pre_job
if: needs.pre_job.outputs.should_skip != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
outputs:
api-changed: ${{ steps.changes.outputs.api }}
e2e-changed: ${{ steps.changes.outputs.e2e }}
@ -69,6 +69,7 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@ -82,6 +83,7 @@ jobs:
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
@ -139,7 +141,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Report skipped API tests
run: echo "No API-related changes detected; skipping API tests."
@ -152,7 +154,7 @@ jobs:
- check-changes
- api-tests-run
- api-tests-skip
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Finalize API Tests status
env:
@ -199,7 +201,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Report skipped web tests
run: echo "No web-related changes detected; skipping web tests."
@ -212,7 +214,7 @@ jobs:
- check-changes
- web-tests-run
- web-tests-skip
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Finalize Web Tests status
env:
@ -258,7 +260,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Report skipped web full-stack e2e
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@ -271,7 +273,7 @@ jobs:
- check-changes
- web-e2e-run
- web-e2e-skip
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Finalize Web Full-Stack E2E status
env:
@ -323,7 +325,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Report skipped VDB tests
run: echo "No VDB-related changes detected; skipping VDB tests."
@ -336,7 +338,7 @@ jobs:
- check-changes
- vdb-tests-run
- vdb-tests-skip
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Finalize VDB Tests status
env:
@ -382,7 +384,7 @@ jobs:
- pre_job
- check-changes
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Report skipped DB migration tests
run: echo "No migration-related changes detected; skipping DB migration tests."
@ -395,7 +397,7 @@ jobs:
- check-changes
- db-migration-test-run
- db-migration-test-skip
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Finalize DB Migration Test status
env:

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with pyrefly diff
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
permissions:
actions: read
contents: read

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-diff:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true

View File

@ -12,7 +12,7 @@ permissions: {}
jobs:
comment:
name: Comment PR with type coverage
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Setup Python & UV
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true

View File

@ -10,7 +10,7 @@ permissions:
jobs:
pyrefly-type-coverage:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
permissions:
contents: read
issues: write
@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true

View File

@ -16,7 +16,7 @@ jobs:
name: Validate PR title
permissions:
pull-requests: read
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Complete merge group check
if: github.event_name == 'merge_group'

View File

@ -12,7 +12,7 @@ on:
jobs:
stale:
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write

View File

@ -15,7 +15,7 @@ permissions:
jobs:
python-style:
name: Python Style
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Checkout code
@ -25,7 +25,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
api/**
@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: false
python-version: "3.12"
@ -57,7 +57,7 @@ jobs:
web-style:
name: Web Style
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./web
@ -73,7 +73,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
@ -83,6 +83,7 @@ jobs:
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@ -94,7 +95,7 @@ jobs:
- name: Restore ESLint cache
if: steps.changed-files.outputs.any_changed == 'true'
id: eslint-cache-restore
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: .eslintcache
key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
@ -123,14 +124,14 @@ jobs:
- name: Save ESLint cache
if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: .eslintcache
key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}
superlinter:
name: SuperLinter
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
steps:
- name: Checkout code
@ -141,7 +142,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
**.sh

View File

@ -9,6 +9,7 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}
@ -17,7 +18,7 @@ concurrency:
jobs:
build:
name: unit test for Node.js SDK
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
defaults:
run:
@ -29,7 +30,7 @@ jobs:
persist-credentials: false
- name: Use Node.js
uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: 22
cache: ''

View File

@ -35,7 +35,7 @@ concurrency:
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
timeout-minutes: 120
steps:
@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@ef50f123a3a9be95b60040d042717517407c7256 # v1.0.110
uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
timeout-minutes: 5
steps:

View File

@ -16,7 +16,7 @@ jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -65,7 +65,7 @@ jobs:
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: VDB Smoke Tests
runs-on: depot-ubuntu-24.04
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -62,7 +62,7 @@ jobs:
# tiflash
- name: Set up Vector Stores for Smoke Coverage
uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@ -13,7 +13,7 @@ concurrency:
jobs:
test:
name: Web Full-Stack E2E
runs-on: depot-ubuntu-24.04-4
runs-on: ubuntu-latest
defaults:
run:
shell: bash
@ -28,7 +28,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Setup UV and Python
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -16,7 +16,7 @@ concurrency:
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
runs-on: depot-ubuntu-24.04-4
runs-on: ubuntu-latest
env:
VITEST_COVERAGE_SCOPE: app-components
strategy:
@ -54,7 +54,7 @@ jobs:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: depot-ubuntu-24.04-4
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:
@ -92,7 +92,7 @@ jobs:
dify-ui-test:
name: dify-ui Tests
runs-on: depot-ubuntu-24.04-4
runs-on: ubuntu-latest
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
defaults:

4
.gitignore vendored
View File

@ -237,10 +237,6 @@ scripts/stress-test/reports/
.playwright-mcp/
.serena/
# vitest browser mode attachments (failure screenshots, traces, etc.)
.vitest-attachments/
**/__screenshots__/
# settings
*.local.json
*.local.md

1
.npmrc Normal file
View File

@ -0,0 +1 @@
save-exact=true

View File

@ -30,7 +30,7 @@ The codebase is split into:
## Language Style
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
## General Practices

View File

@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
```bash
# In your .env file
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
SUGGESTED_QUESTIONS_MAX_TOKENS=512
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
```
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
### Metrics Monitoring with Grafana
Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
@ -147,7 +160,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
### Deployment with Kubernetes
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)

View File

@ -659,11 +659,6 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
@ -714,6 +709,22 @@ SWAGGER_UI_PATH=/swagger-ui.html
# Set to false to export dataset IDs as plain text for easier cross-environment import
DSL_EXPORT_ENCRYPT_DATASET_ID=true
# Suggested Questions After Answer Configuration
# These environment variables allow customization of the suggested questions feature
#
# Custom prompt for generating suggested questions (optional)
# If not set, uses the default prompt that generates 3 questions under 20 characters each
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
# SUGGESTED_QUESTIONS_PROMPT=
# Maximum number of tokens for suggested questions generation (default: 256)
# Adjust this value for longer questions or more questions
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
# Temperature for suggested questions generation (default: 0.0)
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
# SUGGESTED_QUESTIONS_TEMPERATURE=0
# Tenant isolated task queue configuration
TENANT_ISOLATED_TASK_CONCURRENCY=1

View File

@ -101,11 +101,3 @@ The scripts resolve paths relative to their location, so you can run them from a
uv run ruff format ./ # Format code
uv run basedpyright . # Type checking
```
## Generate TS stub
```
uv run dev/generate_swagger_specs.py --output-dir openapi
```
use https://jsontotable.org/openapi-to-typescript to convert to typescript

View File

@ -113,18 +113,8 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
# Validates name encoding for non-Latin characters.
name = name.strip().encode("utf-8").decode("utf-8") if name else None
# Generate a random password that satisfies the password policy.
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
for _ in range(100):
new_password = secrets.token_urlsafe(16)
try:
valid_password(new_password)
break
except Exception:
continue
else:
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
return
# generate random password
new_password = secrets.token_urlsafe(16)
# register account
account = RegisterService.register(

View File

@ -11,7 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.tools.utils.system_encryption import encrypt_system_params
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from extensions.ext_database import db
from models import Tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_params(client_params_dict)
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_params(client_params_dict)
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@ -287,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for Creators Platform integration
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable Creators Platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client ID for Creators Platform integration",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@ -1400,7 +1379,6 @@ class FeatureConfig(
AuthConfig, # Changed from OAuthConfig to AuthConfig
BillingConfig,
CodeExecutionSandboxConfig,
CreatorsPlatformConfig,
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,

View File

@ -1,6 +0,0 @@
from pydantic import BaseModel, JsonValue
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict[str, JsonValue]
action: str

View File

@ -129,7 +129,6 @@ class AppNamePayload(BaseModel):
class AppIconPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon data")
icon_type: IconType | None = Field(default=None, description="Icon type")
icon_background: str | None = Field(default=None, description="Icon background color")
@ -692,32 +691,6 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=None)
@edit_permission_required
def post(self, app_model):
"""Publish app to Creators Platform"""
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"error": "Creators Platform features are not enabled"}, 403
current_user, _ = current_account_with_tenant()
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
dsl_bytes = dsl_content.encode("utf-8")
claim_code = upload_dsl(dsl_bytes)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")
@ -756,12 +729,7 @@ class AppIconApi(Resource):
args = AppIconPayload.model_validate(console_ns.payload or {})
app_service = AppService()
app_model = app_service.update_app_icon(
app_model,
args.icon or "",
args.icon_background or "",
args.icon_type,
)
app_model = app_service.update_app_icon(app_model, args.icon or "", args.icon_background or "")
response_model = AppDetail.model_validate(app_model, from_attributes=True)
return response_model.model_dump(mode="json")

View File

@ -50,7 +50,6 @@ from fields.dataset_fields import (
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermission, DatasetPermissionEnum
from models.enums import ApiTokenType, SegmentStatus
@ -890,8 +889,7 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
base = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
return {"api_base_url": normalize_api_base_url(base)}
return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
@console_ns.route("/datasets/retrieval-setting")

View File

@ -38,48 +38,6 @@ class HitTestingPayload(BaseModel):
class DatasetsHitTestingBase:
@staticmethod
def _normalize_hit_testing_query(query: Any) -> str:
"""Return the user-visible query string from legacy and current response shapes."""
if isinstance(query, str):
return query
if isinstance(query, dict):
content = query.get("content")
if isinstance(content, str):
return content
raise ValueError("Invalid hit testing query response")
@staticmethod
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
"""Coerce nullable collection fields into lists before response validation."""
if not isinstance(records, list):
return []
normalized_records: list[dict[str, Any]] = []
for record in records:
if not isinstance(record, dict):
continue
normalized_record = dict(record)
segment = normalized_record.get("segment")
if isinstance(segment, dict):
normalized_segment = dict(segment)
if normalized_segment.get("keywords") is None:
normalized_segment["keywords"] = []
normalized_record["segment"] = normalized_segment
if normalized_record.get("child_chunks") is None:
normalized_record["child_chunks"] = []
if normalized_record.get("files") is None:
normalized_record["files"] = []
normalized_records.append(normalized_record)
return normalized_records
@staticmethod
def get_and_validate_dataset(dataset_id: str):
assert isinstance(current_user, Account)
@ -117,12 +75,7 @@ class DatasetsHitTestingBase:
attachment_ids=args.get("attachment_ids"),
limit=10,
)
return {
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
marshal(response.get("records", []), hit_testing_record_fields)
),
}
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:

View File

@ -8,10 +8,10 @@ from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@ -34,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -51,11 +56,6 @@ class ConsoleHumanInputFormApi(Resource):
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@staticmethod
def _ensure_console_recipient_type(form: Form) -> None:
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
raise NotFoundError("form not found")
@setup_required
@login_required
@account_initialization_required
@ -99,8 +99,10 @@ class ConsoleHumanInputFormApi(Resource):
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
self._ensure_console_recipient_type(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."

View File

@ -37,11 +37,6 @@ class TagBindingRemovePayload(BaseModel):
type: TagType = Field(description="Tag type")
class TagBindingItemDeletePayload(BaseModel):
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
@ -75,7 +70,6 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagBindingItemDeletePayload,
TagListQueryParam,
TagResponse,
)
@ -158,107 +152,41 @@ class TagUpdateDeleteApi(Resource):
return "", 204
def _require_tag_binding_edit_permission() -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(
tag_ids=payload.tag_ids,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
def _remove_tag_binding() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=payload.tag_id,
target_id=payload.target_id,
type=payload.type,
)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
"""Canonical collection resource for tag binding creation."""
@console_ns.doc("create_tag_binding")
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
"""Canonical item resource for tag binding deletion."""
@console_ns.doc("delete_tag_binding")
@console_ns.doc(params={"id": "Tag ID"})
@console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
_require_tag_binding_edit_permission()
payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(
tag_id=str(id),
target_id=payload.target_id,
type=payload.type,
)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200
@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
"""Deprecated verb-based alias for tag binding creation."""
@console_ns.doc("create_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@console_ns.route("/tag-bindings/remove")
class DeprecatedTagBindingRemoveApi(Resource):
"""Deprecated verb-based alias for tag binding deletion."""
@console_ns.doc("delete_tag_binding_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
class TagBindingDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_binding()
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
return {"result": "success"}, 200

View File

@ -595,25 +595,13 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
# Default to the initial phase; any legacy/unexpected client input is
# coerced back to `old_email` so we never trust the caller to declare
# later phases without a verified predecessor token.
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
# The token used to request a new-email code must come from the
# old-email verification step. This prevents the bypass described
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@ -632,7 +620,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=send_phase,
phase=args.phase,
)
return {"result": "success", "data": token}
@ -667,31 +655,12 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Only advance tokens that were minted by the matching send-code step;
# refuse tokens that have already progressed or lack a phase marker so
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
# is strictly enforced.
phase_transitions = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token that carries the
# upgraded phase so later steps can check it.
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(user_email)
@ -721,29 +690,13 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
# Only tokens that completed both verification phases may be used to
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
# the initial send-code step could be replayed directly here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
# Bind the new email to the token that was mailed and verified, so a
# verified token cannot be reused with a different `new_email` value.
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
# Revoke only after all checks pass so failed attempts don't burn a
# legitimately verified token.
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@ -1,11 +1,3 @@
"""Console workspace endpoint controllers.
This module exposes workspace-scoped plugin endpoint management APIs. The
canonical write routes follow resource-oriented paths, while the historical
verb-based aliases stay available as deprecated resources so OpenAPI metadata
marks only the legacy paths as deprecated.
"""
from typing import Any
from flask import request
@ -33,12 +25,7 @@ class EndpointIdPayload(BaseModel):
endpoint_id: str
class EndpointUpdatePayload(BaseModel):
settings: dict[str, Any]
name: str = Field(min_length=1)
class LegacyEndpointUpdatePayload(EndpointIdPayload):
class EndpointUpdatePayload(EndpointIdPayload):
settings: dict[str, Any]
name: str = Field(min_length=1)
@ -89,7 +76,6 @@ register_schema_models(
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
LegacyEndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
@ -102,60 +88,8 @@ register_schema_models(
)
def _create_endpoint() -> dict[str, bool]:
"""Create a plugin endpoint for the current workspace."""
user, tenant_id = current_account_with_tenant()
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Update a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
name=args.name,
settings=args.settings,
)
}
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
"""Delete a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=endpoint_id,
)
}
@console_ns.route("/workspaces/current/endpoints")
class EndpointCollectionApi(Resource):
"""Canonical collection resource for endpoint creation."""
@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
@console_ns.doc("create_endpoint")
@console_ns.doc(description="Create a new plugin endpoint")
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@ -170,33 +104,22 @@ class EndpointCollectionApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
user, tenant_id = current_account_with_tenant()
args = EndpointCreatePayload.model_validate(console_ns.payload)
@console_ns.route("/workspaces/current/endpoints/create")
class DeprecatedEndpointCreateApi(Resource):
"""Deprecated verb-based alias for endpoint creation."""
@console_ns.doc("create_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
)
)
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.models[EndpointCreateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
)
}
except PluginPermissionDeniedError as e:
raise ValueError(e.description) from e
@console_ns.route("/workspaces/current/endpoints/list")
@ -267,56 +190,10 @@ class EndpointListForSinglePluginApi(Resource):
)
@console_ns.route("/workspaces/current/endpoints/<string:id>")
class EndpointItemApi(Resource):
"""Canonical item resource for endpoint updates and deletion."""
@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource):
@console_ns.doc("delete_endpoint")
@console_ns.doc(description="Delete a plugin endpoint")
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.models[EndpointDeleteResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, id: str):
return _delete_endpoint(endpoint_id=id)
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.models[EndpointUpdateResponse.__name__],
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
def patch(self, id: str):
return _update_endpoint(endpoint_id=id)
@console_ns.route("/workspaces/current/endpoints/delete")
class DeprecatedEndpointDeleteApi(Resource):
"""Deprecated verb-based alias for endpoint deletion."""
@console_ns.doc("delete_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for deleting a plugin endpoint. "
"Use DELETE /workspaces/current/endpoints/{id} instead."
)
)
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
@console_ns.response(
200,
@ -329,23 +206,22 @@ class DeprecatedEndpointDeleteApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
args = EndpointIdPayload.model_validate(console_ns.payload)
return _delete_endpoint(endpoint_id=args.endpoint_id)
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
)
}
@console_ns.route("/workspaces/current/endpoints/update")
class DeprecatedEndpointUpdateApi(Resource):
"""Deprecated verb-based alias for endpoint updates."""
@console_ns.doc("update_endpoint_deprecated")
@console_ns.doc(deprecated=True)
@console_ns.doc(
description=(
"Deprecated legacy alias for updating a plugin endpoint. "
"Use PATCH /workspaces/current/endpoints/{id} instead."
)
)
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
class EndpointUpdateApi(Resource):
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
@console_ns.response(
200,
"Endpoint updated successfully",
@ -357,8 +233,19 @@ class DeprecatedEndpointUpdateApi(Resource):
@is_admin_or_owner_required
@account_initialization_required
def post(self):
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
return _update_endpoint(endpoint_id=args.endpoint_id)
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
endpoint_id=args.endpoint_id,
name=args.name,
settings=args.settings,
)
}
@console_ns.route("/workspaces/current/endpoints/enable")

View File

@ -1131,14 +1131,6 @@ class ToolMCPAuthApi(Resource):
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
parsed = urlparse(server_url)
sanitized_url = f"{parsed.scheme}://{parsed.hostname}{parsed.path}"
logger.warning(
"MCP authorization failed for provider %s (url=%s)",
provider_id,
sanitized_url,
exc_info=True,
)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -23,11 +23,9 @@ from .app import (
conversation,
file,
file_preview,
human_input_form,
message,
site,
workflow,
workflow_events,
)
from .dataset import (
dataset,
@ -52,7 +50,6 @@ __all__ = [
"file",
"file_preview",
"hit_testing",
"human_input_form",
"index",
"message",
"metadata",
@ -61,7 +58,6 @@ __all__ = [
"segment",
"site",
"workflow",
"workflow_events",
]
api.add_namespace(service_api_ns)

View File

@ -1,137 +0,0 @@
"""
Service API human input form endpoints.
This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""
import json
import logging
from datetime import datetime
from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form) -> Response:
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
raise NotFound("Form not found")
def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
# Keep app-token callers scoped to the public web-form surface; internal HITL
# routes must continue to flow through console-only authentication.
if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
raise NotFound("Form not found")
@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
@service_api_ns.doc("get_human_input_form")
@service_api_ns.doc(description="Get a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form retrieved successfully",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token
def get(self, app_model: App, form_token: str):
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
service.ensure_form_active(form)
return _jsonify_form_definition(form)
@service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
@service_api_ns.doc("submit_human_input_form")
@service_api_ns.doc(description="Submit a paused human input form by token")
@service_api_ns.doc(params={"form_token": "Human input form token"})
@service_api_ns.doc(
responses={
200: "Form submitted successfully",
400: "Bad request - invalid submission data",
401: "Unauthorized - invalid API token",
404: "Form not found",
412: "Form already submitted or expired",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, form_token: str):
payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFound("Form not found")
_ensure_form_belongs_to_app(form, app_model)
_ensure_form_is_allowed_for_service_api(form)
recipient_type = form.recipient_type
if recipient_type is None:
logger.warning("Recipient type is None for form, form_id=%s", form.id)
raise BadRequest("Form recipient type is invalid")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=end_user.id,
)
except FormNotFoundError:
raise NotFound("Form not found")
return {}, 200

View File

@ -1,142 +0,0 @@
"""
Service API workflow resume event stream endpoints.
"""
import json
from collections.abc import Generator
from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
"""Service API for getting workflow execution events after resume."""
@service_api_ns.doc("get_workflow_events")
@service_api_ns.doc(description="Get workflow execution events stream after resume")
@service_api_ns.doc(
params={
"task_id": "Workflow run ID",
"user": "End user identifier (query param)",
"include_state_snapshot": (
"Whether to replay from persisted state snapshot, "
'specify `"true"` to include a status snapshot of executed nodes'
),
"continue_on_pause": (
"Whether to keep the stream open across workflow_paused events,"
'specify `"true"` to keep the stream open for `workflow_paused` events.'
),
}
)
@service_api_ns.doc(
responses={
200: "SSE event stream",
401: "Unauthorized - invalid API token",
404: "Workflow run not found",
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, task_id: str):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise NotWorkflowAppError()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=task_id,
)
if workflow_run is None:
raise NotFound("Workflow run not found")
if workflow_run.app_id != app_model.id:
raise NotFound("Workflow run not found")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFound("Workflow run not found")
if workflow_run.created_by != end_user.id:
raise NotFound("Workflow run not found")
workflow_run_entity = workflow_run
if workflow_run_entity.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run_entity.id,
workflow_run=workflow_run_entity,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise NotWorkflowAppError()
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run_entity,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
human_input_surface=HumanInputSurface.SERVICE_API,
close_on_pause=not continue_on_pause,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(
app_mode,
workflow_run_entity.id,
terminal_events=terminal_events,
),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

View File

@ -1,12 +1,4 @@
"""Service API endpoints for dataset document management.
The canonical Service API paths use hyphenated route segments. Legacy underscore
aliases remain registered for backward compatibility, but they must stay marked
deprecated in generated API docs so clients migrate toward the canonical paths.
"""
import json
from collections.abc import Mapping
from contextlib import ExitStack
from typing import Self
from uuid import UUID
@ -125,137 +117,12 @@ register_schema_models(
)
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
"""Create a document from text for both canonical and legacy routes."""
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
dataset_id_str = str(dataset_id)
tenant_id_str = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id_str,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
if not current_user:
raise ValueError("current_user is required")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
"""Update a document from text for both canonical and legacy routes."""
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
)
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args.get("text"):
text = args.get("text")
name = args.get("name")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/document/create_by_text",
"/datasets/<uuid:dataset_id>/document/create-by-text",
)
class DocumentAddByTextApi(DatasetApiResource):
"""Resource for the canonical text document creation route."""
"""Resource for documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text")
@ -271,43 +138,81 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
"""Create document by text."""
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
"""Deprecated resource alias for text document creation."""
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
@service_api_ns.doc("create_document_by_text_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for creating a new document by providing text content. "
"Use /datasets/{dataset_id}/document/create-by-text instead."
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
@service_api_ns.doc(
responses={
200: "Document created successfully",
401: "Unauthorized - invalid API token",
400: "Bad request - invalid parameters",
if not dataset:
raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
embedding_model_provider = payload.embedding_model_provider
embedding_model = payload.embedding_model
if embedding_model_provider and embedding_model:
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID):
"""Create document by text through the deprecated underscore alias."""
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
args["data_source"] = data_source
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
if not current_user:
raise ValueError("current_user is required")
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for the canonical text document update route."""
"""Resource for update documents."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text")
@ -324,35 +229,62 @@ class DocumentUpdateByTextApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text."""
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
"""Deprecated resource alias for text document updates."""
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
@service_api_ns.doc("update_document_by_text_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for updating an existing document by providing text content. "
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
)
)
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
401: "Unauthorized - invalid API token",
404: "Document not found",
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by text through the deprecated underscore alias."""
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
args = payload.model_dump(exclude_none=True)
if not dataset:
raise ValueError("Dataset does not exist.")
retrieval_model = payload.retrieval_model
if (
retrieval_model
and retrieval_model.reranking_model
and retrieval_model.reranking_model.reranking_provider_name
and retrieval_model.reranking_model.reranking_model_name
):
DatasetService.check_reranking_model_setting(
tenant_id,
retrieval_model.reranking_model.reranking_provider_name,
retrieval_model.reranking_model.reranking_model_name,
)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if args.get("text"):
text = args.get("text")
name = args.get("name")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
return documents_and_batch_fields, 200
@service_api_ns.route(
@ -468,98 +400,15 @@ class DocumentAddByFileApi(DatasetApiResource):
return documents_and_batch_fields, 200
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
"""Update a document from an uploaded file for canonical and deprecated routes."""
dataset_id_str = str(dataset_id)
tenant_id_str = str(tenant_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args: dict[str, object] = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
@service_api_ns.route(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
"""Deprecated resource aliases for file document updates."""
class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
@service_api_ns.doc("update_document_by_file_deprecated")
@service_api_ns.doc(deprecated=True)
@service_api_ns.doc(
description=(
"Deprecated legacy alias for updating an existing document by uploading a file. "
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
)
)
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
@ -570,9 +419,82 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file through the deprecated file-update aliases."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
args = {}
if "data" in request.form:
args = json.loads(request.form["data"])
if "doc_form" not in args:
args["doc_form"] = dataset.chunk_structure or "text_model"
if "doc_language" not in args:
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
if "file" in request.files:
# save file info
file = request.files["file"]
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
documents, _ = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
document = documents[0]
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
return documents_and_batch_fields, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@ -886,22 +808,6 @@ class DocumentApi(DatasetApiResource):
return response
@service_api_ns.doc("update_document_by_file")
@service_api_ns.doc(description="Update an existing document by uploading a file")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@service_api_ns.doc(
responses={
200: "Document updated successfully",
401: "Unauthorized - invalid API token",
404: "Document not found",
}
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
"""Update document by file on the canonical document resource."""
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
@service_api_ns.doc("delete_document")
@service_api_ns.doc(description="Delete a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

View File

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -26,6 +26,11 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

View File

@ -1,7 +1,5 @@
from typing import Any
CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
@ -22,11 +20,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""
Validate and set defaults for suggested questions feature.
Optional fields:
- prompt: custom instruction prompt.
- model: provider/model configuration for suggested question generation.
Validate and set defaults for suggested questions feature
:param config: app model config args
"""
@ -45,27 +39,4 @@ class SuggestedQuestionsAfterAnswerConfigManager:
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
prompt = config["suggested_questions_after_answer"].get("prompt")
if prompt is not None and not isinstance(prompt, str):
raise ValueError("prompt in suggested_questions_after_answer must be of string type")
if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
raise ValueError(
f"prompt in suggested_questions_after_answer must be less than or equal to "
f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
)
if "model" in config["suggested_questions_after_answer"]:
model_config = config["suggested_questions_after_answer"]["model"]
if not isinstance(model_config, dict):
raise ValueError("model in suggested_questions_after_answer must be of object type")
if "provider" not in model_config or not isinstance(model_config["provider"], str):
raise ValueError("provider in suggested_questions_after_answer.model must be of string type")
if "name" not in model_config or not isinstance(model_config["name"], str):
raise ValueError("name in suggested_questions_after_answer.model must be of string type")
if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")
return config, ["suggested_questions_after_answer"]

View File

@ -34,11 +34,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@ -659,11 +655,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> (
ChatbotAppBlockingResponse
| AdvancedChatPausedBlockingResponse
| Generator[ChatbotAppStreamResponse, None, None]
):
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -175,7 +175,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Create a variable pool.
# init variable pool
variable_pool = VariablePool.from_bootstrap()
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(

View File

@ -3,7 +3,7 @@ from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
AppBlockingResponse,
AppStreamResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
@ -12,40 +12,22 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
StreamEvent,
)
class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
paused_data = blocking_response.data.model_dump(mode="json")
return {
"event": StreamEvent.WORKFLOW_PAUSED.value,
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
"workflow_run_id": blocking_response.data.workflow_run_id,
"data": paused_data,
}
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
"event": StreamEvent.MESSAGE.value,
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
@ -59,9 +41,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,8 +50,7 @@ class AdvancedChatAppGenerateResponseConverter(
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get("metadata", {})
if isinstance(metadata, dict):
response["metadata"] = cls._get_simple_metadata(metadata)
response["metadata"] = cls._get_simple_metadata(metadata)
return response

View File

@ -53,18 +53,14 @@ from core.app.entities.queue_entities import (
WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
AdvancedChatPausedBlockingResponse,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
PingStreamResponse,
StreamResponse,
WorkflowPauseStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -214,13 +210,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
AdvancedChatPausedBlockingResponse,
Generator[ChatbotAppStreamResponse, None, None],
]:
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@ -236,39 +226,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
"""
Process blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return AdvancedChatPausedBlockingResponse(
task_id=stream_response.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=stream_response.data.workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
status=stream_response.data.status,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
),
)
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
@ -289,41 +254,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> AdvancedChatPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return AdvancedChatPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
data=AdvancedChatPausedBlockingResponse.Data(
id=self._message_id,
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
answer=self._task_state.answer,
metadata=self._message_end_to_stream_response().metadata,
created_at=self._message_created_at,
paused_nodes=paused_nodes,
reasons=reasons,
status=WorkflowExecutionStatus.PAUSED,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:

View File

@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -14,9 +12,11 @@ from core.app.entities.task_entities import (
)
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -1,9 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from pydantic import JsonValue
from typing import Any, Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@ -13,10 +11,8 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)
class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
return cast(TBlockingResponse, response)
class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@classmethod
def convert(
@ -24,7 +20,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -33,7 +29,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@ -43,12 +39,12 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
@classmethod
@abstractmethod
def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
raise NotImplementedError
@classmethod
@ -110,13 +106,13 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
error_responses: dict[type[Exception], dict[str, Any]] = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -130,7 +126,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
}
# Determine the response based on the type of exception
data: dict[str, JsonValue] | None = None
data: dict[str, Any] | None = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -14,9 +12,11 @@ from core.app.entities.task_entities import (
)
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,

View File

@ -52,7 +52,6 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@ -337,26 +336,7 @@ class WorkflowResponseConverter:
except (TypeError, json.JSONDecodeError):
definition_payload = {}
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
form_token_by_form_id = load_form_tokens_by_form_id(
human_input_form_ids,
session=session,
surface=(
HumanInputSurface.SERVICE_API
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
else None
),
)
# Reconnect paths must preserve the same pause-reason contract as live streams;
# otherwise clients see schema drift after resume.
pause_reasons = enrich_human_input_pause_reasons(
pause_reasons,
form_tokens_by_form_id=form_token_by_form_id,
expiration_times_by_form_id={
form_id: int(expiration_time.timestamp())
for form_id, expiration_time in expiration_times_by_form_id.items()
},
)
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
responses: list[StreamResponse] = []

View File

@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any, cast
from pydantic import JsonValue
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
@ -14,15 +12,17 @@ from core.app.entities.task_entities import (
)
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
response: dict[str, Any] = {
response = {
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
yield "ping"
continue
response_chunk: dict[str, JsonValue] = {
response_chunk = {
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,

View File

@ -1,7 +1,6 @@
from collections.abc import Callable, Generator, Iterable, Mapping
from collections.abc import Callable, Generator, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.task_entities import StreamEvent
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
@ -27,7 +26,6 @@ class MessageGenerator:
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
@ -35,5 +33,4 @@ class MessageGenerator:
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
terminal_events=terminal_events,
)

View File

@ -13,9 +13,11 @@ from core.app.entities.task_entities import (
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@ -24,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -27,11 +27,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.datasource.entities.datasource_entities import (
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
@ -631,11 +627,7 @@ class PipelineGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -144,7 +144,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
)
variable_pool = VariablePool.from_bootstrap()
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(

View File

@ -59,7 +59,7 @@ def stream_topic_events(
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if terminal_events is None:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:

View File

@ -25,11 +25,7 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
@ -616,11 +612,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Account | EndUser,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> (
WorkflowAppBlockingResponse
| WorkflowAppPausedBlockingResponse
| Generator[WorkflowAppStreamResponse, None, None]
):
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -106,7 +106,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool.from_bootstrap()
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(

View File

@ -9,29 +9,24 @@ from core.app.entities.task_entities import (
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.model_dump())
return blocking_response.model_dump()
@classmethod
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@ -42,15 +42,12 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
HumanInputRequiredPauseReasonPayload,
HumanInputRequiredResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppPausedBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
@ -121,11 +118,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(
self,
) -> Union[
WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
]:
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
@ -136,24 +129,19 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
else:
return self._to_blocking_response(generator)
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
"""
human_input_responses: list[HumanInputRequiredResponse] = []
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, HumanInputRequiredResponse):
human_input_responses.append(stream_response)
elif isinstance(stream_response, WorkflowPauseStreamResponse):
return WorkflowAppPausedBlockingResponse(
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
@ -164,13 +152,12 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
paused_nodes=stream_response.data.paused_nodes,
reasons=stream_response.data.reasons,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
return WorkflowAppBlockingResponse(
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
@ -187,44 +174,12 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
),
)
return response
else:
continue
if human_input_responses:
return self._build_paused_blocking_response_from_human_input(human_input_responses)
raise ValueError("queue listening stopped unexpectedly.")
def _build_paused_blocking_response_from_human_input(
self, human_input_responses: list[HumanInputRequiredResponse]
) -> WorkflowAppPausedBlockingResponse:
runtime_state = self._resolve_graph_runtime_state()
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
created_at = int(runtime_state.start_at)
reasons = [
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
for response in human_input_responses
]
return WorkflowAppPausedBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=human_input_responses[-1].workflow_run_id,
data=WorkflowAppPausedBlockingResponse.Data(
id=human_input_responses[-1].workflow_run_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PAUSED,
outputs={},
error=None,
elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
total_tokens=runtime_state.total_tokens,
total_steps=runtime_state.node_run_steps,
created_at=created_at,
finished_at=None,
paused_nodes=paused_nodes,
reasons=reasons,
),
)
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:

View File

@ -188,7 +188,7 @@ class WorkflowBasedAppRunner:
ValueError: If neither single_iteration_run nor single_loop_run is specified
"""
# Create initial runtime state with variable pool containing environment variables
variable_pool = VariablePool.from_bootstrap()
variable_pool = VariablePool()
add_variables_to_pool(
variable_pool,
build_bootstrap_variables(

View File

@ -1,13 +1,12 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, JsonValue
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from graphon.entities import WorkflowStartReason
from graphon.entities.pause_reason import PauseReasonType
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from graphon.nodes.human_input.entities import FormInput, UserAction
@ -296,40 +295,6 @@ class HumanInputRequiredResponse(StreamResponse):
data: Data
class HumanInputRequiredPauseReasonPayload(BaseModel):
"""
Public pause-reason payload used by blocking responses when only
``human_input_required`` events are available.
"""
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
@classmethod
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
return cls(
form_id=data.form_id,
node_id=data.node_id,
node_title=data.node_title,
form_content=data.form_content,
inputs=data.inputs,
actions=data.actions,
display_in_ui=data.display_in_ui,
form_token=data.form_token,
resolved_default_values=data.resolved_default_values,
expiration_time=data.expiration_time,
)
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
@ -390,7 +355,7 @@ class NodeStartStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
@ -447,7 +412,7 @@ class NodeFinishStreamResponse(StreamResponse):
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
@ -809,34 +774,6 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
data: Data
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
"""
ChatbotAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
message_id: str
workflow_run_id: str
answer: str
metadata: Mapping[str, object] = Field(default_factory=dict)
created_at: int
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
status: WorkflowExecutionStatus
elapsed_time: float
total_tokens: int
total_steps: int
data: Data
class CompletionAppBlockingResponse(AppBlockingResponse):
"""
CompletionAppBlockingResponse entity
@ -882,33 +819,6 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
data: Data
class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
"""
WorkflowAppPausedBlockingResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
workflow_id: str
status: WorkflowExecutionStatus
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
total_tokens: int
total_steps: int
created_at: int
finished_at: int | None
paused_nodes: Sequence[str] = Field(default_factory=list)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from collections.abc import Generator # Changed from Iterator
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
token = _current_file_access_scope.set(scope)
try:
yield

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@ -15,21 +14,8 @@ from graphon.nodes.llm.protocols import CredentialsProvider
class DifyCredentialsProvider:
"""Resolves and returns LLM credentials for a given provider and model.
Fetched credentials are stored in :attr:`credentials_cache` and reused for
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
Because of that cache, a single instance can return stale credentials after
the tenant or provider configuration changes (e.g. API key rotation).
Do **not** keep one instance for the lifetime of a process or across
unrelated invocations. Create a new provider per request, workflow run, or
other bounded scope where up-to-date credentials matter.
"""
tenant_id: str
provider_manager: ProviderManager
credentials_cache: dict[tuple[str, str], dict[str, Any]]
def __init__(
self,
@ -44,12 +30,8 @@ class DifyCredentialsProvider:
user_id=run_context.user_id,
)
self.provider_manager = provider_manager
self.credentials_cache = {}
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
if (provider_name, model_name) in self.credentials_cache:
return deepcopy(self.credentials_cache[(provider_name, model_name)])
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
provider_configuration = provider_configurations.get(provider_name)
if not provider_configuration:
@ -64,7 +46,6 @@ class DifyCredentialsProvider:
if credentials is None:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
return credentials
@ -84,8 +65,7 @@ class DifyModelFactory:
provider_manager=create_plugin_provider_manager(
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
),
enable_credentials_cache=True,
)
)
self.model_manager = model_manager
@ -104,7 +84,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
tenant_id=run_context.tenant_id,
user_id=run_context.user_id,
)
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
model_manager = ModelManager(provider_manager=provider_manager)
return (
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@ -16,7 +16,6 @@ from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol
if TYPE_CHECKING:
from graphon.nodes.llm.node import LLMNode
@ -117,8 +116,7 @@ class LLMQuotaLayer(GraphEngineLayer):
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
model_instance = cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
typed_node: QuestionClassifierNode = cast("QuestionClassifierNode", node)
model_instance = cast(PreparedLLMProtocol, typed_node._model_instance)
model_instance = cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:

View File

@ -24,7 +24,6 @@ from core.entities.provider_entities import (
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from graphon.model_runtime import ModelRuntime
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
ConfigurateMethod,
@ -34,6 +33,7 @@ from graphon.model_runtime.entities.provider_entities import (
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from graphon.model_runtime.runtime import ModelRuntime
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.enums import CredentialSourceType
@ -109,7 +109,7 @@ class ProviderConfiguration(BaseModel):
def get_model_provider_factory(self) -> ModelProviderFactory:
"""Return a provider factory that preserves any request-bound runtime."""
if self._bound_model_runtime is not None:
return ModelProviderFactory(runtime=self._bound_model_runtime)
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
@ -1392,12 +1392,10 @@ class ProviderConfiguration(BaseModel):
:param model_type: model type
:return:
"""
from core.plugin.impl.model_runtime_factory import create_model_type_instance
model_provider_factory = self.get_model_provider_factory()
return create_model_type_instance(
factory=model_provider_factory, provider=self.provider.provider, model_type=model_type
)
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None

View File

@ -1,41 +0,0 @@
"""
Helper module for Creators Platform integration.
Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""
import logging
from urllib.parse import urlencode
import httpx
from yarl import URL
from configs import dify_config
logger = logging.getLogger(__name__)
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
response.raise_for_status()
data = response.json()
claim_code = data.get("data", {}).get("claim_code")
if not claim_code:
raise ValueError("Creators Platform did not return a valid claim_code")
return claim_code
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
params: dict[str, str] = {"dsl_claim_code": claim_code}
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
if client_id:
from services.oauth_server import OAuthServerService
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
params["oauth_code"] = oauth_code
return f"{base_url}?{urlencode(params)}"

View File

@ -4,7 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.plugin.impl.model_runtime_factory import create_model_type_instance, create_plugin_model_provider_factory
from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory
from extensions.ext_hosting_provider import hosting_configuration
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeBadRequestError
@ -44,8 +44,8 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id)
# Get model instance of LLM
model_type_instance = create_model_type_instance(
factory=model_provider_factory, provider=openai_provider_name, model_type=ModelType.MODERATION
model_type_instance = model_provider_factory.get_model_type_instance(
provider=openai_provider_name, model_type=ModelType.MODERATION
)
model_type_instance = cast(ModerationModel, model_type_instance)
moderation_result = model_type_instance.invoke(

View File

@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Any, NotRequired, Protocol, TypedDict, cast
from typing import Any, Protocol, TypedDict, cast
import json_repair
from sqlalchemy import select
@ -18,6 +18,8 @@ from core.llm_generator.prompts import (
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SUGGESTED_QUESTIONS_MAX_TOKENS,
SUGGESTED_QUESTIONS_TEMPERATURE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
@ -39,36 +41,6 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
class SuggestedQuestionsModelConfig(TypedDict):
provider: str
name: str
completion_params: NotRequired[dict[str, object]]
def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]:
"""
Normalize raw completion params into invocation parameters and stop sequences.
This mirrors the app-model access path by separating ``stop`` from provider
parameters before invocation, then drops non-positive token limits because
some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their
provider-specific output-token field.
"""
normalized_parameters = dict(completion_params)
stop_value = normalized_parameters.pop("stop", [])
if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value):
stop = stop_value
else:
stop = []
for token_limit_key in ("max_tokens", "max_output_tokens"):
token_limit = normalized_parameters.get(token_limit_key)
if isinstance(token_limit, int | float) and token_limit <= 0:
normalized_parameters.pop(token_limit_key, None)
return normalized_parameters, stop
class WorkflowServiceInterface(Protocol):
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
pass
@ -151,15 +123,8 @@ class LLMGenerator:
return name
@classmethod
def generate_suggested_questions_after_answer(
cls,
tenant_id: str,
histories: str,
*,
instruction_prompt: str | None = None,
model_config: object | None = None,
) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt)
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")
@ -168,36 +133,10 @@ class LLMGenerator:
try:
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {}
provider = configured_model.get("provider")
model_name = configured_model.get("name")
use_configured_model = False
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
try:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=provider,
model=model_name,
)
use_configured_model = True
except Exception:
logger.warning(
"Failed to use configured suggested-questions model %s/%s, fallback to default model",
provider,
model_name,
exc_info=True,
)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
except InvokeAuthorizationError:
return []
@ -206,29 +145,19 @@ class LLMGenerator:
questions: Sequence[str] = []
try:
configured_completion_params = configured_model.get("completion_params")
if use_configured_model and isinstance(configured_completion_params, dict):
model_parameters, stop = _normalize_completion_params(configured_completion_params)
elif use_configured_model:
model_parameters = {}
stop = []
else:
# Default-model generation keeps the built-in suggested-questions tuning.
model_parameters = {
"max_tokens": 2560,
"temperature": 0.0,
}
stop = []
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=model_parameters,
stop=stop,
model_parameters={
"max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
"temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
},
stream=False,
)
text_content = response.message.get_text_content()
questions = output_parser.parse(text_content) if text_content else []
except InvokeError:
questions = []
except Exception:
logger.exception("Failed to generate suggested questions after answer")
questions = []

View File

@ -3,28 +3,17 @@ import logging
import re
from collections.abc import Sequence
from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
logger = logging.getLogger(__name__)
class SuggestedQuestionsAfterAnswerOutputParser:
def __init__(self, instruction_prompt: str | None = None) -> None:
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
@staticmethod
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
if not instruction_prompt or not instruction_prompt.strip():
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
def get_format_instructions(self) -> str:
return self._instruction_prompt
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str) -> Sequence[str]:
stripped_text = text.strip()
action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL)
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
questions: list[str] = []
if action_match is not None:
try:
@ -34,6 +23,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
else:
if isinstance(json_obj, list):
questions = [question for question in json_obj if isinstance(question, str)]
elif stripped_text:
logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200])
return questions

View File

@ -1,4 +1,5 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
import os
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the users input into two parts: “Intention” and “Subject”.
@ -95,8 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
)
# Default prompt and model parameters for suggested questions.
DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
# Default prompt for suggested questions (can be overridden by environment variable)
_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keep each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
@ -104,6 +105,15 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
'["question1","question2","question3"]\n'
)
# Environment variable override for suggested questions prompt
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
"SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
)
# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))
GENERATOR_QA_PROMPT = (
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
" in the long text. Please think step by step."

View File

@ -303,16 +303,9 @@ class StreamableHTTPTransport:
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
error_msg = (
f"MCP server URL returned 404 Not Found: {self.url} "
"— verify the server URL is correct and the server is running"
if is_initialization
else "Session terminated by server"
)
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
message=error_msg,
)
return
@ -388,13 +381,12 @@ class StreamableHTTPTransport:
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
message: str = "Session terminated by server",
):
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message=message),
error=ErrorData(code=32600, message="Session terminated by server"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)

View File

@ -1,6 +1,5 @@
import logging
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from copy import deepcopy
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
from configs import dify_config
@ -37,13 +36,11 @@ class ModelInstance:
Model instance class.
"""
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
if credentials is None:
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.credentials = credentials
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
@ -437,30 +434,8 @@ class ModelInstance:
class ModelManager:
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
``(tenant_id, provider, model_type, model)`` are stored in
``_credentials_cache`` and reused. That can return **stale** credentials after
API keys or provider settings change, so a manager constructed with
``enable_credentials_cache=True`` should not be kept for the lifetime of a
process or shared across unrelated work. Prefer a new manager per request,
workflow run, or similar bounded scope.
The default is ``enable_credentials_cache=False``; in that mode the internal
credential cache is not populated, and each ``get_model_instance`` call
loads credentials from the current provider configuration.
"""
def __init__(
self,
provider_manager: ProviderManager,
*,
enable_credentials_cache: bool = False,
) -> None:
def __init__(self, provider_manager: ProviderManager):
self._provider_manager = provider_manager
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
self._enable_credentials_cache = enable_credentials_cache
@classmethod
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@ -488,19 +463,8 @@ class ModelManager:
tenant_id=tenant_id, provider=provider, model_type=model_type
)
cred_cache_key = (tenant_id, provider, model_type.value, model)
if cred_cache_key in self._credentials_cache:
return ModelInstance(
provider_model_bundle,
model,
deepcopy(self._credentials_cache[cred_cache_key]),
)
ret = ModelInstance(provider_model_bundle, model)
if self._enable_credentials_cache:
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
return ret
model_instance = ModelInstance(provider_model_bundle, model)
return model_instance
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
"""

View File

@ -4,7 +4,7 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Literal, Union, overload
from typing import IO, Any, Union
from pydantic import ValidationError
from redis import RedisError
@ -14,18 +14,13 @@ from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from extensions.ext_redis import redis_client
from graphon.model_runtime import ModelRuntime
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult
from graphon.model_runtime.runtime import ModelRuntime
from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
@ -200,34 +195,6 @@ class PluginModelRuntime(ModelRuntime):
return schema
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResult: ...
@overload
def invoke_llm(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
tools: list[PromptMessageTool] | None,
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
def invoke_llm(
self,
*,
@ -255,50 +222,6 @@ class PluginModelRuntime(ModelRuntime):
stream=stream,
)
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[False],
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
self,
*,
provider: str,
model: str,
credentials: dict[str, Any],
json_schema: dict[str, Any],
model_parameters: dict[str, Any],
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
stream: bool,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
# TODO: added to pass type check.
# it is a new method from upstream that is not invoked at all.
raise NotImplementedError
def get_llm_num_tokens(
self,
*,

View File

@ -3,14 +3,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import Speech2TextModel
from graphon.model_runtime.model_providers.base.text_embedding_model import TextEmbeddingModel
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
if TYPE_CHECKING:
@ -18,15 +10,6 @@ if TYPE_CHECKING:
from core.plugin.impl.model_runtime import PluginModelRuntime
from core.provider_manager import ProviderManager
_MODEL_TYPE_CLASS_MAP: dict[ModelType, type[AIModel]] = {
ModelType.LLM: LargeLanguageModel,
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
ModelType.RERANK: RerankModel,
ModelType.SPEECH2TEXT: Speech2TextModel,
ModelType.MODERATION: ModerationModel,
ModelType.TTS: TTSModel,
}
class PluginModelAssembly:
"""Compose request-scoped model views on top of a single plugin runtime."""
@ -55,7 +38,7 @@ class PluginModelAssembly:
@property
def model_provider_factory(self) -> ModelProviderFactory:
if self._model_provider_factory is None:
self._model_provider_factory = ModelProviderFactory(runtime=self.model_runtime)
self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
return self._model_provider_factory
@property
@ -104,30 +87,3 @@ def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None
def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
"""Create a tenant-bound model manager for service flows."""
return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
def create_model_type_instance(
factory: ModelProviderFactory,
provider: str,
model_type: ModelType,
) -> AIModel:
"""Instantiate the AIModel subclass for *model_type* backed by *factory*'s runtime.
This replaces ``ModelProviderFactory.get_model_type_instance`` which was
removed in graphon 0.3.0. The mapping from ModelType to concrete AIModel
subclass is maintained here so that callers do not need to know the
subclass constructors.
:param factory: factory whose ``runtime`` and provider resolution are used.
:param provider: provider identifier (canonical or short name).
:param model_type: the model type to instantiate.
:returns: an AIModel subclass instance wired to the factory's runtime.
:raises ValueError: if *model_type* is not supported.
"""
model_class = _MODEL_TYPE_CLASS_MAP.get(model_type)
if model_class is None:
msg = f"Unsupported model type: {model_type}"
raise ValueError(msg)
provider_entity = factory.get_model_provider(provider)
return model_class(provider_schema=provider_entity, model_runtime=factory.runtime)

View File

@ -151,12 +151,6 @@ def deserialize_response(raw_data: bytes) -> Response:
response = Response(response=body, status=status_code)
# Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
# parsed ones so we faithfully reproduce the original response. Use Headers.add
# rather than dict-style assignment so that repeated headers such as Set-Cookie
# (and any other multi-valued header per RFC 9110) are preserved instead of
# being overwritten.
response.headers.clear()
for line in lines[1:]:
if not line:
continue
@ -164,6 +158,6 @@ def deserialize_response(raw_data: bytes) -> Response:
if ":" not in line_str:
continue
name, value = line_str.split(":", 1)
response.headers.add(name, value.strip())
response.headers[name] = value.strip()
return response

View File

@ -56,7 +56,7 @@ from models.provider_ids import ModelProviderID
from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime import ModelRuntime
from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
@ -70,32 +70,12 @@ class ProviderManager:
Request-bound managers may carry caller identity in that runtime, and the
resulting ``ProviderConfiguration`` objects must reuse it for downstream
model-type and schema lookups.
Configuration assembly is cached per manager instance so call chains that
share one request-scoped manager can reuse the same provider graph instead
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
when a long-lived manager needs to observe writes performed within the same
instance scope.
"""
decoding_rsa_key: Any | None
decoding_cipher_rsa: Any | None
_model_runtime: ModelRuntime
_configurations_cache: dict[str, ProviderConfigurations]
def __init__(self, model_runtime: ModelRuntime):
self.decoding_rsa_key = None
self.decoding_cipher_rsa = None
self._model_runtime = model_runtime
self._configurations_cache = {}
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
"""Drop assembled provider configurations cached on this manager instance."""
if tenant_id is None:
self._configurations_cache.clear()
return
self._configurations_cache.pop(tenant_id, None)
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
"""
@ -134,10 +114,6 @@ class ProviderManager:
:param tenant_id:
:return:
"""
cached_configurations = self._configurations_cache.get(tenant_id)
if cached_configurations is not None:
return cached_configurations
# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
@ -165,7 +141,7 @@ class ProviderManager:
)
# Get all provider entities
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_entities = model_provider_factory.get_providers()
# Get All preferred provider types of the workspace
@ -297,8 +273,6 @@ class ProviderManager:
provider_configurations[str(provider_id_entity)] = provider_configuration
self._configurations_cache[tenant_id] = provider_configurations
# Return the encapsulated object
return provider_configurations
@ -362,7 +336,7 @@ class ProviderManager:
if not default_model:
return None
model_provider_factory = ModelProviderFactory(runtime=self._model_runtime)
model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime)
provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name)
return DefaultModelEntity(

View File

@ -139,10 +139,8 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
if dataset_keyword_table is None:
return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
@ -156,8 +154,7 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

View File

@ -1,5 +1,4 @@
import re
from collections.abc import Callable
from operator import itemgetter
from typing import cast
@ -81,14 +80,12 @@ class JiebaKeywordTableHandler:
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = cast(int | None, kwargs.pop("topK", top_k))
if top_k is None:
top_k = 20
top_k = kwargs.pop("topK", top_k)
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
tokens = list(cut(sentence))
else:
tokens = re.findall(r"\w+", sentence)
@ -109,9 +106,9 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk or 10,
topK=max_keywords_per_chunk,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(set(keywords)))

View File

@ -158,7 +158,7 @@ class RetrievalService:
)
if futures:
for _ in concurrent.futures.as_completed(futures, timeout=3600):
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
@ -551,7 +551,6 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

View File

@ -39,58 +39,6 @@ class AbstractVectorFactory(ABC):
return index_struct_dict
class _LazyEmbeddings(Embeddings):
"""Lazy proxy that defers materializing the real embedding model.
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
transitively calls ``FeatureService.get_features`` → ``BillingService``
HTTP GETs (see ``provider_manager.py``). Cleanup paths
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
at all, so deferring this until an ``embed_*`` method is actually invoked
keeps cleanup tasks resilient to transient billing-API failures and avoids
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
hiccups.
Existing callers that perform create / search operations are unaffected:
the first ``embed_*`` call materializes the underlying model and the
behavior is identical from that point on.
"""
def __init__(self, dataset: Dataset):
self._dataset = dataset
self._real: Embeddings | None = None
def _ensure(self) -> Embeddings:
if self._real is None:
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
self._real = CacheEmbedding(embedding_model)
return self._real
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)
async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
@ -112,11 +60,7 @@ class Vector:
"original_chunk_id",
]
self._dataset = dataset
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
# never transitively trigger billing API calls during ``Vector(dataset)``
# construction. The real embedding model is materialized only when an
# ``embed_*`` method is actually invoked (i.e. create / search paths).
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()

View File

@ -11,7 +11,6 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType
class DatasetDocumentStore:
@ -128,7 +127,6 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -139,7 +137,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type=SegmentType.AUTOMATIC,
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
@ -165,7 +163,6 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -176,7 +173,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type=SegmentType.AUTOMATIC,
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)

View File

@ -94,7 +94,6 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@ -105,7 +104,6 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":

View File

@ -28,10 +28,10 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
result: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False, # pyright: ignore[reportArgumentType]
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import codecs
import re
from collections.abc import Set as AbstractSet
from collections.abc import Collection
from typing import Any, Literal
from core.model_manager import ModelInstance
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@ -40,7 +40,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
return [len(text) for text in texts]
_ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)

Some files were not shown because too many files have changed in this diff Show More