Mirror of https://github.com/langgenius/dify.git, synced 2026-05-09 20:08:06 +08:00

Compare commits: 53 commits
Commit SHAs in this compare:

4c0f33d0d7, 668e9c7864, a7c481ce87, 507eb1f52f, 8e2ab1367b, 0c568623d7, fb7b8dc151, 4bc1046f14,
f2ec17be9b, 6532b4d161, 4bfc4af590, 1fb7329327, 40ae39a3a3, 35c08f7c3d, 7b6ceaebea, 35d9b6a0f8,
d1c1c04615, 04ebf8a92f, 6f3c2fe97b, 03cd16fc44, 3a6901e718, 25034612b8, 87620050d7, e006eb7a4b,
305de57eff, 069fdd4894, 783dfe38a0, 86ba361ff1, 591048d7c2, 8a62c1d915, b083c910b3, 9b2a37ceff,
cf5ebe9430, 85c3f9cbf8, d98fe7916a, 0b3b0b5ce8, eb5ef3dba5, a07b32274a, 2a38df2b7f, 71e9e8dda6,
772f450b29, 390f1f74db, b7bd9c19ed, e93821af46, 9408759954, fe9412af5d, 218ef6a447, 501c0b8746,
4214583ae5, 73771cb58c, f5f224f49d, 813da349ec, fe8510ad1a
Documentation changes (testing guidelines; file names not shown in this view):

```diff
@@ -367,7 +367,7 @@ For each extraction:
 ┌────────────────────────────────────────┐
 │ 1. Extract code                        │
 │ 2. Run: pnpm lint:fix                  │
-│ 3. Run: pnpm type-check                │
+│ 3. Run: pnpm type-check:tsgo           │
 │ 4. Run: pnpm test                      │
 │ 5. Test functionality manually         │
 │ 6. PASS? → Next extraction             │
@@ -200,7 +200,7 @@ When assigned to test a directory/path, test **ALL content** within that path:

 - ✅ **Import real project components** directly (including base components and siblings)
 - ✅ **Only mock**: API services (`@/service/*`), `next/navigation`, complex context providers
-- ❌ **DO NOT mock** base components (`@/app/components/base/*`) or dify-ui primitives (`@langgenius/dify-ui/*`)
+- ❌ **DO NOT mock** base components (`@/app/components/base/*`)
 - ❌ **DO NOT mock** sibling/child components in the same directory

 > See [Test Structure Template](#test-structure-template) for correct import/mock patterns.
@@ -325,12 +325,12 @@ For more detailed information, refer to:
 ### Reference Examples in Codebase

 - `web/utils/classnames.spec.ts` - Utility function tests
-- `web/app/components/base/radio/__tests__/index.spec.tsx` - Component tests
+- `web/app/components/base/button/index.spec.tsx` - Component tests
 - `web/__mocks__/provider-context.ts` - Mock factory example

 ### Project Configuration

-- `web/vite.config.ts` - Vite/Vitest configuration
+- `web/vitest.config.ts` - Vitest configuration
 - `web/vitest.setup.ts` - Test environment setup
 - `web/scripts/analyze-component.js` - Component analysis tool
 - Modules are not mocked automatically. Global mocks live in `web/vitest.setup.ts` (for example `react-i18next`, `next/image`); mock other modules like `ky` or `mime` locally in test files.
```
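The local-mock pattern in that last bullet looks roughly like this. A minimal sketch, assuming a module under test that calls `ky`; the spec name and response shape are invented, only `ky` and the Vitest APIs are real:

```typescript
// some-module.spec.ts (illustrative only)
import { beforeEach, describe, expect, it, vi } from 'vitest'

// `react-i18next` and `next/image` are already mocked globally in
// `web/vitest.setup.ts`, so nothing to do for them here. `ky` is not
// mocked globally, so mock it locally in this spec file:
vi.mock('ky', () => ({
  default: {
    // ky.get(url).json() is the real call shape being stubbed out
    get: vi.fn(() => ({ json: async () => ({ ok: true }) })),
  },
}))

describe('module that fetches with ky', () => {
  beforeEach(() => {
    vi.clearAllMocks()
  })

  it('sees the locally mocked client', async () => {
    const ky = (await import('ky')).default
    const res = await ky.get('/api/example').json()
    expect(res).toEqual({ ok: true })
  })
})
```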
```diff
@@ -36,7 +36,7 @@ Use this checklist when generating or reviewing tests for Dify frontend components
 ### Integration vs Mocking

-- [ ] **DO NOT mock base components or dify-ui primitives** (base `Loading`, `Input`, `Badge`; dify-ui `Button`, `Tooltip`, `Dialog`, etc.)
+- [ ] **DO NOT mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
 - [ ] Import real project components instead of mocking
 - [ ] Only mock: API calls, complex context providers, third-party libs with side effects
 - [ ] Prefer integration testing when using single spec file
@@ -73,7 +73,7 @@ Use this checklist when generating or reviewing tests for Dify frontend components
 ### Mocks

-- [ ] **DO NOT mock base components or dify-ui primitives** (`@/app/components/base/*` or `@langgenius/dify-ui/*`)
+- [ ] **DO NOT mock base components** (`@/app/components/base/*`)
 - [ ] `vi.clearAllMocks()` in `beforeEach` (not `afterEach`)
 - [ ] Shared mock state reset in `beforeEach`
 - [ ] i18n uses global mock (auto-loaded in `web/vitest.setup.ts`); only override locally for custom translations
```
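The `beforeEach` reset items in that checklist come together roughly as follows; a sketch only, with a placeholder service path rather than a real Dify module:

```typescript
import { beforeEach, describe, expect, it, vi } from 'vitest'

// `vi.mock` is hoisted above imports, so shared mock state must be created
// with `vi.hoisted` to be visible inside the mock factory.
const { fetchApps } = vi.hoisted(() => ({ fetchApps: vi.fn() }))
vi.mock('@/service/apps', () => ({ fetchApps })) // placeholder service path

describe('component under test', () => {
  beforeEach(() => {
    // Reset in beforeEach (not afterEach) so every test starts clean even
    // when an earlier test leaked state or failed before cleanup ran.
    vi.clearAllMocks()
    fetchApps.mockResolvedValue([]) // re-prime the default behaviour per test
  })

  it('uses the freshly primed mock', async () => {
    expect(await fetchApps()).toEqual([])
  })
})
```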
````diff
@@ -127,7 +127,7 @@ For the current file being tested:
 - [ ] Run full directory test: `pnpm test path/to/directory/`
 - [ ] Check coverage report: `pnpm test:coverage`
 - [ ] Run `pnpm lint:fix` on all test files
-- [ ] Run `pnpm type-check`
+- [ ] Run `pnpm type-check:tsgo`

 ## Common Issues to Watch

@@ -2,27 +2,29 @@

 ## ⚠️ Important: What NOT to Mock

-### DO NOT Mock Base Components or dify-ui Primitives
+### DO NOT Mock Base Components

-**Never mock components from `@/app/components/base/` or from `@langgenius/dify-ui/*`** such as:
+**Never mock components from `@/app/components/base/`** such as:

-- Legacy base (`@/app/components/base/*`): `Loading`, `Spinner`, `Input`, `Badge`, `Tag`
-- dify-ui primitives (`@langgenius/dify-ui/*`): `Button`, `Tooltip`, `Dialog`, `Popover`, `DropdownMenu`, `ContextMenu`, `Select`, `AlertDialog`, `Toast`
+- `Loading`, `Spinner`
+- `Button`, `Input`, `Select`
+- `Tooltip`, `Modal`, `Dropdown`
+- `Icon`, `Badge`, `Tag`

 **Why?**

-- These components have their own dedicated tests
+- Base components will have their own dedicated tests
 - Mocking them creates false positives (tests pass but real integration fails)
 - Using real components tests actual integration behavior

 ```typescript
-// ❌ WRONG: Don't mock base components or dify-ui primitives
+// ❌ WRONG: Don't mock base components
 vi.mock('@/app/components/base/loading', () => () => <div>Loading</div>)
-vi.mock('@langgenius/dify-ui/button', () => ({ Button: ({ children }: any) => <button>{children}</button> }))
+vi.mock('@/app/components/base/button', () => ({ children }: any) => <button>{children}</button>)

-// ✅ CORRECT: Import and use the real components
+// ✅ CORRECT: Import and use real base components
 import Loading from '@/app/components/base/loading'
-import { Button } from '@langgenius/dify-ui/button'
+import Button from '@/app/components/base/button'
 // They will render normally in tests
 ```
@@ -317,7 +319,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 ### ✅ DO

-1. **Use real base components and dify-ui primitives** - Import from `@/app/components/base/` or `@langgenius/dify-ui/*` directly
+1. **Use real base components** - Import from `@/app/components/base/` directly
 1. **Use real project components** - Prefer importing over mocking
 1. **Use real Zustand stores** - Set test state via `store.setState()`
 1. **Reset mocks in `beforeEach`**, not `afterEach`
@@ -328,7 +330,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 ### ❌ DON'T

-1. **Don't mock base components or dify-ui primitives** (`Loading`, `Input`, `Button`, `Tooltip`, `Dialog`, etc.)
+1. **Don't mock base components** (`Loading`, `Button`, `Tooltip`, etc.)
 1. **Don't mock Zustand store modules** - Use real stores with `setState()`
 1. Don't mock components you can import directly
 1. Don't create overly simplified mocks that miss conditional logic
@@ -340,7 +342,7 @@ const renderWithQueryClient = (ui: React.ReactElement) => {
 ```
 Need to use a component in test?
 │
-├─ Is it from @/app/components/base/* or @langgenius/dify-ui/*?
+├─ Is it from @/app/components/base/*?
 │  └─ YES → Import real component, DO NOT mock
 │
 ├─ Is it a project component?
````
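Put together, a spec that follows this decision tree might look like the sketch below. Everything here is hypothetical (the component, service path, and rendered copy are invented); it only illustrates the mock-the-service, keep-the-real-components shape the tree describes:

```typescript
// app-list.spec.tsx (hypothetical spec following the decision tree above)
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'

// Service module leaf of the tree: the one thing we mock (path is illustrative):
vi.mock('@/service/apps', () => ({
  fetchAppList: vi.fn().mockResolvedValue({ data: [] }),
}))

// Base components and siblings are imported for real by the component under
// test, so the spec exercises actual integration behavior.
import AppList from './index' // hypothetical component under test

describe('AppList', () => {
  it('renders the empty state once the mocked service resolves', async () => {
    render(<AppList />)
    expect(await screen.findByText('No apps yet')).toBeTruthy() // assumed copy
  })
})
```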
.github/workflows/anti-slop.yml (new file, +19)

```diff
@@ -0,0 +1,19 @@
+name: Anti-Slop PR Check
+
+on:
+  pull_request_target:
+    types: [opened, edited, synchronize]
+
+permissions:
+  pull-requests: write
+  contents: read
+
+jobs:
+  anti-slop:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1
+        with:
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          close-pr: false
+          failure-add-pr-labels: "needs-revision"
```
.github/workflows/api-tests.yml (14 lines changed)

```diff
@@ -16,7 +16,7 @@ concurrency:
 jobs:
   api-unit:
     name: API Unit Tests
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     env:
       COVERAGE_FILE: coverage-unit
     defaults:
@@ -35,7 +35,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@@ -62,7 +62,7 @@ jobs:

   api-integration:
     name: API Integration Tests
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     env:
       COVERAGE_FILE: coverage-integration
       STORAGE_TYPE: opendal
@@ -84,7 +84,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@@ -105,7 +105,7 @@ jobs:
         run: sh .github/workflows/expose_service_ports.sh

       - name: Set up Sandbox
-        uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
+        uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
@@ -137,7 +137,7 @@ jobs:

   api-coverage:
     name: API Coverage
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     needs:
       - api-unit
       - api-integration
@@ -156,7 +156,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: "3.12"
```
.github/workflows/autofix.yml (12 lines changed)

```diff
@@ -13,7 +13,7 @@ permissions:
 jobs:
   autofix:
     if: github.repository == 'langgenius/dify'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
@@ -25,7 +25,7 @@ jobs:
       - name: Check Docker Compose inputs
         if: github.event_name != 'merge_group'
         id: docker-compose-changes
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             docker/generate_docker_compose
@@ -35,7 +35,7 @@ jobs:
       - name: Check web inputs
         if: github.event_name != 'merge_group'
         id: web-changes
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             web/**
@@ -48,7 +48,7 @@ jobs:
       - name: Check api inputs
         if: github.event_name != 'merge_group'
         id: api-changes
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             api/**
@@ -58,7 +58,7 @@ jobs:
           python-version: "3.11"

       - if: github.event_name != 'merge_group'
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0

       - name: Generate Docker Compose
         if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -123,4 +123,4 @@ jobs:
             vp exec eslint --concurrency=2 --prune-suppressions --quiet || true

       - if: github.event_name != 'merge_group'
-        uses: autofix-ci/action@c5b2d67aa2274e7b5a18224e8171550871fc7e4a # v1.3.4
+        uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
```
.github/workflows/build-push.yml (46 lines changed)

```diff
@@ -26,9 +26,6 @@ jobs:
   build:
     runs-on: ${{ matrix.runs_on }}
     if: github.repository == 'langgenius/dify'
-    permissions:
-      contents: read
-      id-token: write
     strategy:
       matrix:
         include:
@@ -38,28 +35,28 @@ jobs:
             build_context: "{{defaultContext}}:api"
             file: "Dockerfile"
             platform: linux/amd64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-latest
           - service_name: "build-api-arm64"
             image_name_env: "DIFY_API_IMAGE_NAME"
             artifact_context: "api"
             build_context: "{{defaultContext}}:api"
             file: "Dockerfile"
             platform: linux/arm64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-24.04-arm
           - service_name: "build-web-amd64"
             image_name_env: "DIFY_WEB_IMAGE_NAME"
             artifact_context: "web"
             build_context: "{{defaultContext}}"
             file: "web/Dockerfile"
             platform: linux/amd64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-latest
           - service_name: "build-web-arm64"
             image_name_env: "DIFY_WEB_IMAGE_NAME"
             artifact_context: "web"
             build_context: "{{defaultContext}}"
             file: "web/Dockerfile"
             platform: linux/arm64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-24.04-arm

     steps:
       - name: Prepare
@@ -73,8 +70,8 @@ jobs:
           username: ${{ env.DOCKERHUB_USER }}
           password: ${{ env.DOCKERHUB_TOKEN }}

-      - name: Set up Depot CLI
-        uses: depot/setup-action@v1
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0

       - name: Extract metadata for Docker
         id: meta
@@ -84,15 +81,16 @@ jobs:

       - name: Build Docker image
         id: build
-        uses: depot/build-push-action@v1
+        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
         with:
-          project: ${{ vars.DEPOT_PROJECT_ID }}
           context: ${{ matrix.build_context }}
           file: ${{ matrix.file }}
           platforms: ${{ matrix.platform }}
           build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
           labels: ${{ steps.meta.outputs.labels }}
           outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
+          cache-from: type=gha,scope=${{ matrix.service_name }}
+          cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}

       - name: Export digest
         env:
@@ -110,33 +108,9 @@ jobs:
           if-no-files-found: error
           retention-days: 1

-  fork-build-validate:
-    if: github.repository != 'langgenius/dify'
-    runs-on: ubuntu-24.04
-    strategy:
-      matrix:
-        include:
-          - service_name: "validate-api-amd64"
-            build_context: "{{defaultContext}}:api"
-            file: "Dockerfile"
-          - service_name: "validate-web-amd64"
-            build_context: "{{defaultContext}}"
-            file: "web/Dockerfile"
-    steps:
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
-
-      - name: Validate Docker image
-        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
-        with:
-          push: false
-          context: ${{ matrix.build_context }}
-          file: ${{ matrix.file }}
-          platforms: linux/amd64
-
   create-manifest:
     needs: build
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     if: github.repository == 'langgenius/dify'
     strategy:
       matrix:
```
.github/workflows/db-migration-test.yml (34 lines changed)

```diff
@@ -9,7 +9,7 @@ concurrency:

 jobs:
   db-migration-test-postgres:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest

     steps:
       - name: Checkout code
@@ -19,7 +19,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
           cp middleware.env.example middleware.env

       - name: Set up Middlewares
-        uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
+        uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
@@ -59,7 +59,7 @@ jobs:
         run: uv run --directory api flask upgrade-db

   db-migration-test-mysql:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest

     steps:
       - name: Checkout code
@@ -69,7 +69,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
           sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env

       - name: Set up Middlewares
-        uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
+        uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
         with:
           compose-file: |
             docker/docker-compose.middleware.yaml
@@ -110,28 +110,6 @@ jobs:
           sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
           sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env

-      # hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
-      # to return (container processes started); it does not wait on healthcheck
-      # status. mysql:8.0's first-time init takes 15-30s, so without an explicit
-      # wait the migration runs while InnoDB is still initialising and gets
-      # killed with "Lost connection during query". Poll a real SELECT until it
-      # succeeds.
-      - name: Wait for MySQL to accept queries
-        run: |
-          set +e
-          for i in $(seq 1 60); do
-            if docker run --rm --network host mysql:8.0 \
-              mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
-              -e 'SELECT 1' >/dev/null 2>&1; then
-              echo "MySQL ready after ${i}s"
-              exit 0
-            fi
-            sleep 1
-          done
-          echo "MySQL not ready after 60s; dumping container logs:"
-          docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
-          exit 1
-
       - name: Run DB Migration
         env:
           DEBUG: true
```
.github/workflows/deploy-agent-dev.yml (2 lines changed)

```diff
@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/agent-dev'
```

.github/workflows/deploy-dev.yml (2 lines changed)

```diff
@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/dev'
```

.github/workflows/deploy-enterprise.yml (2 lines changed)

```diff
@@ -13,7 +13,7 @@ on:

 jobs:
   deploy:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'deploy/enterprise'
```

.github/workflows/deploy-hitl.yml (2 lines changed)

```diff
@@ -10,7 +10,7 @@ on:

 jobs:
   deploy:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     if: |
       github.event.workflow_run.conclusion == 'success' &&
       github.event.workflow_run.head_branch == 'build/feat/hitl'
```
.github/workflows/docker-build.yml (47 lines changed)

```diff
@@ -14,69 +14,40 @@ concurrency:

 jobs:
   build-docker:
-    if: github.event.pull_request.head.repo.full_name == github.repository
     runs-on: ${{ matrix.runs_on }}
-    permissions:
-      contents: read
-      id-token: write
     strategy:
       matrix:
         include:
           - service_name: "api-amd64"
             platform: linux/amd64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-latest
             context: "{{defaultContext}}:api"
             file: "Dockerfile"
           - service_name: "api-arm64"
             platform: linux/arm64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-24.04-arm
             context: "{{defaultContext}}:api"
             file: "Dockerfile"
           - service_name: "web-amd64"
             platform: linux/amd64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-latest
             context: "{{defaultContext}}"
             file: "web/Dockerfile"
           - service_name: "web-arm64"
             platform: linux/arm64
-            runs_on: depot-ubuntu-24.04-4
+            runs_on: ubuntu-24.04-arm
             context: "{{defaultContext}}"
             file: "web/Dockerfile"
     steps:
-      - name: Set up Depot CLI
-        uses: depot/setup-action@v1
-
-      - name: Build Docker Image
-        uses: depot/build-push-action@v1
-        with:
-          project: ${{ vars.DEPOT_PROJECT_ID }}
-          push: false
-          context: ${{ matrix.context }}
-          file: ${{ matrix.file }}
-          platforms: ${{ matrix.platform }}
-
-  build-docker-fork:
-    if: github.event.pull_request.head.repo.full_name != github.repository
-    runs-on: ubuntu-24.04
-    permissions:
-      contents: read
-    strategy:
-      matrix:
-        include:
-          - service_name: "api-amd64"
-            context: "{{defaultContext}}:api"
-            file: "Dockerfile"
-          - service_name: "web-amd64"
-            context: "{{defaultContext}}"
-            file: "web/Dockerfile"
-    steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@98e3b2c9eab4f4f98a95c0c0a3ea5e5e672fd2a8 # v3.10.0
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0

       - name: Build Docker Image
-        uses: docker/build-push-action@5cd29d66b4a8d8e6f4d5dfe2e9329f0b1d446289 # v6.18.0
+        uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
         with:
           push: false
           context: ${{ matrix.context }}
           file: ${{ matrix.file }}
-          platforms: linux/amd64
+          platforms: ${{ matrix.platform }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
```
.github/workflows/labeler.yml (2 lines changed)

```diff
@@ -7,7 +7,7 @@ jobs:
     permissions:
       contents: read
       pull-requests: write
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
         with:
```
.github/workflows/main-ci.yml (24 lines changed)

```diff
@@ -23,7 +23,7 @@ concurrency:
 jobs:
   pre_job:
     name: Skip Duplicate Checks
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     outputs:
       should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
     steps:
@@ -39,7 +39,7 @@ jobs:
     name: Check Changed Files
     needs: pre_job
     if: needs.pre_job.outputs.should_skip != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     outputs:
       api-changed: ${{ steps.changes.outputs.api }}
       e2e-changed: ${{ steps.changes.outputs.e2e }}
@@ -141,7 +141,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Report skipped API tests
         run: echo "No API-related changes detected; skipping API tests."
@@ -154,7 +154,7 @@ jobs:
       - check-changes
       - api-tests-run
       - api-tests-skip
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Finalize API Tests status
         env:
@@ -201,7 +201,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Report skipped web tests
         run: echo "No web-related changes detected; skipping web tests."
@@ -214,7 +214,7 @@ jobs:
       - check-changes
       - web-tests-run
       - web-tests-skip
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Finalize Web Tests status
         env:
@@ -260,7 +260,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Report skipped web full-stack e2e
         run: echo "No E2E-related changes detected; skipping web full-stack E2E."
@@ -273,7 +273,7 @@ jobs:
       - check-changes
       - web-e2e-run
       - web-e2e-skip
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Finalize Web Full-Stack E2E status
         env:
@@ -325,7 +325,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Report skipped VDB tests
         run: echo "No VDB-related changes detected; skipping VDB tests."
@@ -338,7 +338,7 @@ jobs:
       - check-changes
       - vdb-tests-run
       - vdb-tests-skip
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Finalize VDB Tests status
         env:
@@ -384,7 +384,7 @@ jobs:
       - pre_job
       - check-changes
     if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Report skipped DB migration tests
         run: echo "No migration-related changes detected; skipping DB migration tests."
@@ -397,7 +397,7 @@ jobs:
       - check-changes
       - db-migration-test-run
       - db-migration-test-skip
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     steps:
       - name: Finalize DB Migration Test status
         env:
```
.github/workflows/pyrefly-diff-comment.yml (2 lines changed)

```diff
@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with pyrefly diff
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     permissions:
       actions: read
       contents: read
```

.github/workflows/pyrefly-diff.yml (4 lines changed)

```diff
@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-diff:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     permissions:
       contents: read
       issues: write
@@ -22,7 +22,7 @@ jobs:
           fetch-depth: 0

       - name: Setup Python & UV
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
```

A further workflow (file name not shown in this view):

```diff
@@ -12,7 +12,7 @@ permissions: {}
 jobs:
   comment:
     name: Comment PR with type coverage
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     permissions:
       actions: read
       contents: read
@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2

       - name: Setup Python & UV
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
```

.github/workflows/pyrefly-type-coverage.yml (4 lines changed)

```diff
@@ -10,7 +10,7 @@ permissions:

 jobs:
   pyrefly-type-coverage:
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     permissions:
       contents: read
       issues: write
@@ -22,7 +22,7 @@ jobs:
           fetch-depth: 0

       - name: Setup Python & UV
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
```
.github/workflows/semantic-pull-request.yml (2 lines changed)

```diff
@@ -16,7 +16,7 @@ jobs:
     name: Validate PR title
     permissions:
       pull-requests: read
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
    steps:
       - name: Complete merge group check
         if: github.event_name == 'merge_group'
```

.github/workflows/stale.yml (2 lines changed)

```diff
@@ -12,7 +12,7 @@ on:
 jobs:
   stale:

-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     permissions:
       issues: write
       pull-requests: write
```
.github/workflows/style.yml (20 lines changed)

```diff
@@ -15,7 +15,7 @@ permissions:
 jobs:
   python-style:
     name: Python Style
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest

     steps:
       - name: Checkout code
@@ -25,7 +25,7 @@ jobs:

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             api/**
@@ -33,7 +33,7 @@ jobs:

       - name: Setup UV and Python
         if: steps.changed-files.outputs.any_changed == 'true'
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: false
           python-version: "3.12"
@@ -57,7 +57,7 @@ jobs:

   web-style:
     name: Web Style
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     defaults:
       run:
         working-directory: ./web
@@ -73,7 +73,7 @@ jobs:

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             web/**
@@ -95,7 +95,7 @@ jobs:
       - name: Restore ESLint cache
         if: steps.changed-files.outputs.any_changed == 'true'
         id: eslint-cache-restore
-        uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
+        uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
         with:
           path: .eslintcache
           key: ${{ runner.os }}-eslint-${{ hashFiles('pnpm-lock.yaml', 'eslint.config.mjs', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
@@ -110,8 +110,6 @@ jobs:
       - name: Web tsslint
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        env:
-          NODE_OPTIONS: --max-old-space-size=4096
         run: vp run lint:tss

       - name: Web type check
@@ -126,14 +124,14 @@ jobs:

       - name: Save ESLint cache
         if: steps.changed-files.outputs.any_changed == 'true' && success() && steps.eslint-cache-restore.outputs.cache-hit != 'true'
-        uses: actions/cache/save@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
+        uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
         with:
           path: .eslintcache
           key: ${{ steps.eslint-cache-restore.outputs.cache-primary-key }}

   superlinter:
     name: SuperLinter
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest

     steps:
       - name: Checkout code
@@ -144,7 +142,7 @@ jobs:

       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@9426d40962ed5378910ee2e21d5f8c6fcbf2dd96 # v47.0.6
+        uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
         with:
           files: |
             **.sh
```
.github/workflows/tool-test-sdks.yaml (4 lines changed)

```diff
@@ -18,7 +18,7 @@ concurrency:
 jobs:
   build:
     name: unit test for Node.js SDK
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest

     defaults:
       run:
@@ -30,7 +30,7 @@ jobs:
           persist-credentials: false

       - name: Use Node.js
-        uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6.4.0
+        uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
         with:
           node-version: 22
           cache: ''
```

.github/workflows/translate-i18n-claude.yml (4 lines changed)

```diff
@@ -35,7 +35,7 @@ concurrency:
 jobs:
   translate:
     if: github.repository == 'langgenius/dify'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     timeout-minutes: 120

     steps:
@@ -158,7 +158,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.context.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@567fe954a4527e81f132d87d1bdbcc94f7737434 # v1.0.107
+        uses: anthropics/claude-code-action@b47fd721da662d48c5680e154ad16a73ed74d2e0 # v1.0.93
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
```

.github/workflows/trigger-i18n-sync.yml (2 lines changed)

```diff
@@ -16,7 +16,7 @@ concurrency:
 jobs:
   trigger:
     if: github.repository == 'langgenius/dify'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     timeout-minutes: 5

     steps:
```
.github/workflows/vdb-tests-full.yml (6 lines changed)

```diff
@@ -16,7 +16,7 @@ jobs:
   test:
     name: Full VDB Tests
     if: github.repository == 'langgenius/dify'
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     strategy:
       matrix:
         python-version:
@@ -36,7 +36,7 @@ jobs:
           remove_tool_cache: true

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@@ -65,7 +65,7 @@ jobs:
           # tiflash

       - name: Set up Full Vector Store Matrix
-        uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
+        uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
         with:
           compose-file: |
             docker/docker-compose.yaml
```

.github/workflows/vdb-tests.yml (6 lines changed)

```diff
@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: VDB Smoke Tests
-    runs-on: depot-ubuntu-24.04
+    runs-on: ubuntu-latest
     strategy:
       matrix:
         python-version:
@@ -33,7 +33,7 @@ jobs:
           remove_tool_cache: true

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
@@ -62,7 +62,7 @@ jobs:
           # tiflash

       - name: Set up Vector Stores for Smoke Coverage
-        uses: hoverkraft-tech/compose-action@d2bee4f07e8ca410d6b196d00f90c12e7d48c33a # v2.6.0
+        uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
         with:
           compose-file: |
             docker/docker-compose.yaml
```
.github/workflows/web-e2e.yml (4 lines changed)

```diff
@@ -13,7 +13,7 @@ concurrency:
 jobs:
   test:
     name: Web Full-Stack E2E
-    runs-on: depot-ubuntu-24.04-4
+    runs-on: ubuntu-latest
     defaults:
       run:
         shell: bash
@@ -28,7 +28,7 @@ jobs:
         uses: ./.github/actions/setup-web

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0
+        uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
         with:
           enable-cache: true
           python-version: "3.12"
```

.github/workflows/web-tests.yml (6 lines changed)

```diff
@@ -16,7 +16,7 @@ concurrency:
 jobs:
   test:
     name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
-    runs-on: depot-ubuntu-24.04-4
+    runs-on: ubuntu-latest
     env:
       VITEST_COVERAGE_SCOPE: app-components
     strategy:
@@ -54,7 +54,7 @@ jobs:
     name: Merge Test Reports
     if: ${{ !cancelled() }}
     needs: [test]
-    runs-on: depot-ubuntu-24.04-4
+    runs-on: ubuntu-latest
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
@@ -92,7 +92,7 @@ jobs:

   dify-ui-test:
     name: dify-ui Tests
-    runs-on: depot-ubuntu-24.04-4
+    runs-on: ubuntu-latest
     env:
       CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
     defaults:
```
.gitignore (4 lines changed)

```diff
@@ -237,10 +237,6 @@ scripts/stress-test/reports/
 .playwright-mcp/
 .serena/

-# vitest browser mode attachments (failure screenshots, traces, etc.)
-.vitest-attachments/
-**/__screenshots__/
-
 # settings
 *.local.json
 *.local.md
```
A contributor guidelines document (file name not shown in this view):

```diff
@@ -30,7 +30,7 @@ The codebase is split into:
 ## Language Style

 - **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
-- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
+- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.

 ## General Practices
```
README.md (15 lines changed)

````diff
@@ -139,6 +139,19 @@ Star Dify on GitHub and be instantly notified of new releases.

 If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).

+#### Customizing Suggested Questions
+
+You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
+
+```bash
+# In your .env file
+SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
+SUGGESTED_QUESTIONS_MAX_TOKENS=512
+SUGGESTED_QUESTIONS_TEMPERATURE=0.3
+```
+
+See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
+
 ### Metrics Monitoring with Grafana

 Import the dashboard to Grafana, using Dify's PostgreSQL database as data source, to monitor metrics in granularity of apps, tenants, messages, and more.
@@ -147,7 +160,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source

 ### Deployment with Kubernetes

-If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
+If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.

 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
 - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
````
An environment example file (file name not shown in this view):

```diff
@@ -659,11 +659,6 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
 MARKETPLACE_ENABLED=true
 MARKETPLACE_API_URL=https://marketplace.dify.ai

-# Creators Platform configuration
-CREATORS_PLATFORM_FEATURES_ENABLED=true
-CREATORS_PLATFORM_API_URL=https://creators.dify.ai
-CREATORS_PLATFORM_OAUTH_CLIENT_ID=
-
 # Endpoint configuration
 ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
@@ -714,6 +709,22 @@ SWAGGER_UI_PATH=/swagger-ui.html
 # Set to false to export dataset IDs as plain text for easier cross-environment import
 DSL_EXPORT_ENCRYPT_DATASET_ID=true

+# Suggested Questions After Answer Configuration
+# These environment variables allow customization of the suggested questions feature
+#
+# Custom prompt for generating suggested questions (optional)
+# If not set, uses the default prompt that generates 3 questions under 20 characters each
+# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
+# SUGGESTED_QUESTIONS_PROMPT=
+
+# Maximum number of tokens for suggested questions generation (default: 256)
+# Adjust this value for longer questions or more questions
+# SUGGESTED_QUESTIONS_MAX_TOKENS=256
+
+# Temperature for suggested questions generation (default: 0.0)
+# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
+# SUGGESTED_QUESTIONS_TEMPERATURE=0
+
 # Tenant isolated task queue configuration
 TENANT_ISOLATED_TASK_CONCURRENCY=1
```
A development README (file name not shown in this view):

````diff
@@ -101,11 +101,3 @@ The scripts resolve paths relative to their location, so you can run them from a
 uv run ruff format ./      # Format code
 uv run basedpyright .      # Type checking
 ```
-
-## Generate TS stub
-
-```
-uv run dev/generate_swagger_specs.py --output-dir openapi
-```
-
-use https://jsontotable.org/openapi-to-typescript to convert to typescript
````
Extension initialization module (file name not shown in this view):

```diff
@@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
         ext_logstore,
         ext_mail,
         ext_migrate,
+        ext_oauth_bearer,
         ext_orjson,
         ext_otel,
         ext_proxy_fix,
@@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
         ext_enterprise_telemetry,
         ext_request_logging,
         ext_session_factory,
+        ext_oauth_bearer,
     ]
     for ext in extensions:
         short_name = ext.__name__.split(".")[-1]
```
CLI commands module (file name not shown in this view):

```diff
@@ -11,7 +11,7 @@ from configs import dify_config
 from core.helper import encrypter
 from core.plugin.entities.plugin_daemon import CredentialType
 from core.plugin.impl.plugin import PluginInstaller
-from core.tools.utils.system_encryption import encrypt_system_params
+from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
 from extensions.ext_database import db
 from models import Tenant
 from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
@@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):

         click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
         click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-        oauth_client_params = encrypt_system_params(client_params_dict)
+        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
         click.echo(click.style("Client params encrypted successfully.", fg="green"))
     except Exception as e:
         click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):

         click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
         click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
-        oauth_client_params = encrypt_system_params(client_params_dict)
+        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
         click.echo(click.style("Client params encrypted successfully.", fg="green"))
     except Exception as e:
         click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
```
Feature configuration module (file name not shown in this view):

```diff
@@ -287,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
     )


-class CreatorsPlatformConfig(BaseSettings):
-    """
-    Configuration for Creators Platform integration
-    """
-
-    CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
-        description="Enable or disable Creators Platform features",
-        default=True,
-    )
-
-    CREATORS_PLATFORM_API_URL: HttpUrl = Field(
-        description="Creators Platform API URL",
-        default=HttpUrl("https://creators.dify.ai"),
-    )
-
-    CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
-        description="OAuth client ID for Creators Platform integration",
-        default="",
-    )
-
-
 class EndpointConfig(BaseSettings):
     """
     Configuration for various application endpoints and URLs
@@ -520,6 +499,35 @@ class HttpConfig(BaseSettings):
     def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
         return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")

+    inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
+        description=(
+            "Comma-separated allowlist for /openapi/v1/* CORS. "
+            "Default empty = same-origin only. Browser-cookie routes within "
+            "the group reject cross-origin OPTIONS regardless of this list."
+        ),
+        validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
+        default="",
+    )
+
+    @computed_field
+    def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
+        return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
+
+    inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
+        description=(
+            "Comma-separated client_id values accepted at "
+            "POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
+            "without code changes. Unknown client_id returns 400 unsupported_client."
+        ),
+        validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
+        default="difyctl",
+    )
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
+        return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
+
     HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
         ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
     )
@@ -895,6 +903,17 @@ class AuthConfig(BaseSettings):
         default=86400,
     )

+    ENABLE_OAUTH_BEARER: bool = Field(
+        description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
+        default=True,
+    )
+
+    OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
+        description="Per-token rate limit on /openapi/v1/* (requests per minute). "
+        "Bucket keyed on sha256(token), shared across api replicas via Redis.",
+        default=60,
+    )
+

 class ModerationConfig(BaseSettings):
     """
@@ -1169,6 +1188,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
         description="Enable scheduled workflow run cleanup task",
         default=False,
     )
+    ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
+        description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
+        default=True,
+    )
+    OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
+        description="Days to retain revoked OAuth access-token rows before deletion.",
+        default=30,
+    )
     ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
         description="Enable mail clean document notify task",
         default=False,
@@ -1400,7 +1427,6 @@ class FeatureConfig(
     AuthConfig,  # Changed from OAuthConfig to AuthConfig
     BillingConfig,
     CodeExecutionSandboxConfig,
-    CreatorsPlatformConfig,
     TriggerConfig,
     AsyncWorkflowConfig,
     PluginConfig,
```
controllers/common/human_input.py (module deleted; path taken from the import removed below):

```diff
@@ -1,6 +0,0 @@
-from pydantic import BaseModel, JsonValue
-
-
-class HumanInputFormSubmitPayload(BaseModel):
-    inputs: dict[str, JsonValue]
-    action: str
```
Console app controller (file name not shown in this view):

```diff
@@ -692,32 +692,6 @@ class AppExportApi(Resource):
         return payload.model_dump(mode="json")


-@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
-class AppPublishToCreatorsPlatformApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=None)
-    @edit_permission_required
-    def post(self, app_model):
-        """Publish app to Creators Platform"""
-        from configs import dify_config
-        from core.helper.creators import get_redirect_url, upload_dsl
-
-        if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
-            return {"error": "Creators Platform features are not enabled"}, 403
-
-        current_user, _ = current_account_with_tenant()
-
-        dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
-        dsl_bytes = dsl_content.encode("utf-8")
-
-        claim_code = upload_dsl(dsl_bytes)
-        redirect_url = get_redirect_url(str(current_user.id), claim_code)
-
-        return {"redirect_url": redirect_url}
-
-
 @console_ns.route("/apps/<uuid:app_id>/name")
 class AppNameApi(Resource):
     @console_ns.doc("check_app_name")
```
Dataset hit-testing controller (file name not shown in this view):

```diff
@@ -38,48 +38,6 @@ class HitTestingPayload(BaseModel):


 class DatasetsHitTestingBase:
-    @staticmethod
-    def _normalize_hit_testing_query(query: Any) -> str:
-        """Return the user-visible query string from legacy and current response shapes."""
-        if isinstance(query, str):
-            return query
-
-        if isinstance(query, dict):
-            content = query.get("content")
-            if isinstance(content, str):
-                return content
-
-        raise ValueError("Invalid hit testing query response")
-
-    @staticmethod
-    def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
-        """Coerce nullable collection fields into lists before response validation."""
-        if not isinstance(records, list):
-            return []
-
-        normalized_records: list[dict[str, Any]] = []
-        for record in records:
-            if not isinstance(record, dict):
-                continue
-
-            normalized_record = dict(record)
-            segment = normalized_record.get("segment")
-            if isinstance(segment, dict):
-                normalized_segment = dict(segment)
-                if normalized_segment.get("keywords") is None:
-                    normalized_segment["keywords"] = []
-                normalized_record["segment"] = normalized_segment
-
-            if normalized_record.get("child_chunks") is None:
-                normalized_record["child_chunks"] = []
-
-            if normalized_record.get("files") is None:
-                normalized_record["files"] = []
-
-            normalized_records.append(normalized_record)
-
-        return normalized_records
-
     @staticmethod
     def get_and_validate_dataset(dataset_id: str):
         assert isinstance(current_user, Account)
@@ -117,12 +75,7 @@ class DatasetsHitTestingBase:
                 attachment_ids=args.get("attachment_ids"),
                 limit=10,
             )
-            return {
-                "query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
-                "records": DatasetsHitTestingBase._normalize_hit_testing_records(
-                    marshal(response.get("records", []), hit_testing_record_fields)
-                ),
-            }
+            return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
         except ProviderTokenNotInitError as ex:
```
Console human-input form controller (file name not shown in this view):

```diff
@@ -8,10 +8,10 @@ from collections.abc import Generator

 from flask import Response, jsonify, request
 from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

-from controllers.common.human_input import HumanInputFormSubmitPayload
 from controllers.console import console_ns
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
@@ -20,11 +20,11 @@ from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
 from core.app.apps.message_generator import MessageGenerator
 from core.app.apps.workflow.app_generator import WorkflowAppGenerator
-from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
 from extensions.ext_database import db
 from libs.login import current_account_with_tenant, login_required
 from models import App
 from models.enums import CreatorUserRole
+from models.human_input import RecipientType
 from models.model import AppMode
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
@@ -34,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 logger = logging.getLogger(__name__)


+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -51,11 +56,6 @@ class ConsoleHumanInputFormApi(Resource):
         if form.tenant_id != current_tenant_id:
             raise NotFoundError("App not found")

-    @staticmethod
-    def _ensure_console_recipient_type(form: Form) -> None:
-        if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.CONSOLE):
-            raise NotFoundError("form not found")
-
     @setup_required
     @login_required
     @account_initialization_required
@@ -99,8 +99,10 @@ class ConsoleHumanInputFormApi(Resource):
             raise NotFoundError(f"form not found, token={form_token}")

         self._ensure_console_access(form)
-        self._ensure_console_recipient_type(form)

+        recipient_type = form.recipient_type
+        if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
+            raise NotFoundError(f"form not found, token={form_token}")
         # The type checker is not smart enought to validate the following invariant.
         # So we need to assert it manually.
         assert recipient_type is not None, "recipient_type cannot be None here."
```
@@ -37,11 +37,6 @@ class TagBindingRemovePayload(BaseModel):
    type: TagType = Field(description="Tag type")


class TagBindingItemDeletePayload(BaseModel):
    target_id: str = Field(description="Target ID to unbind tag from")
    type: TagType = Field(description="Tag type")


class TagListQueryParam(BaseModel):
    type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
    keyword: str | None = Field(None, description="Search keyword")
@@ -75,7 +70,6 @@ register_schema_models(
    TagBasePayload,
    TagBindingPayload,
    TagBindingRemovePayload,
    TagBindingItemDeletePayload,
    TagListQueryParam,
    TagResponse,
)
@@ -158,107 +152,41 @@ class TagUpdateDeleteApi(Resource):
        return "", 204


def _require_tag_binding_edit_permission() -> None:
    """
    Ensure the current account can edit tag bindings.

    Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
    """
    current_user, _ = current_account_with_tenant()
    # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()


def _create_tag_bindings() -> tuple[dict[str, str], int]:
    _require_tag_binding_edit_permission()

    payload = TagBindingPayload.model_validate(console_ns.payload or {})
    TagService.save_tag_binding(
        TagBindingCreatePayload(
            tag_ids=payload.tag_ids,
            target_id=payload.target_id,
            type=payload.type,
        )
    )
    return {"result": "success"}, 200


def _remove_tag_binding() -> tuple[dict[str, str], int]:
    _require_tag_binding_edit_permission()

    payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
    TagService.delete_tag_binding(
        TagBindingDeletePayload(
            tag_id=payload.tag_id,
            target_id=payload.target_id,
            type=payload.type,
        )
    )
    return {"result": "success"}, 200


@console_ns.route("/tag-bindings")
class TagBindingCollectionApi(Resource):
    """Canonical collection resource for tag binding creation."""

    @console_ns.doc("create_tag_binding")
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        return _create_tag_bindings()
        current_user, _ = current_account_with_tenant()
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()


@console_ns.route("/tag-bindings/<uuid:id>")
class TagBindingItemApi(Resource):
    """Canonical item resource for tag binding deletion."""

    @console_ns.doc("delete_tag_binding")
    @console_ns.doc(params={"id": "Tag ID"})
    @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, id):
        _require_tag_binding_edit_permission()
        payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {})
        TagService.delete_tag_binding(
            TagBindingDeletePayload(
                tag_id=str(id),
                target_id=payload.target_id,
                type=payload.type,
            )
        payload = TagBindingPayload.model_validate(console_ns.payload or {})
        TagService.save_tag_binding(
            TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
        )

        return {"result": "success"}, 200


@console_ns.route("/tag-bindings/create")
class DeprecatedTagBindingCreateApi(Resource):
    """Deprecated verb-based alias for tag binding creation."""

    @console_ns.doc("create_tag_binding_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.")
    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        return _create_tag_bindings()


@console_ns.route("/tag-bindings/remove")
class DeprecatedTagBindingRemoveApi(Resource):
    """Deprecated verb-based alias for tag binding deletion."""

    @console_ns.doc("delete_tag_binding_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.")
class TagBindingDeleteApi(Resource):
    @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        return _remove_tag_binding()
        current_user, _ = current_account_with_tenant()
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
        TagService.delete_tag_binding(
            TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
        )

        return {"result": "success"}, 200

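Reading the hunk counts above (107 lines removed against 41 added), the canonical collection/item resources and their deprecated aliases appear on the removed side. A reference-only summary of the routes involved, with the annotations being editorial rather than code from the diff:

# Paths copied from the hunk; descriptions are mine.
TAG_BINDING_ROUTES = {
    "POST /tag-bindings": "canonical create (TagBindingCollectionApi, removed side)",
    "DELETE /tag-bindings/<uuid:id>": "canonical delete (TagBindingItemApi, removed side)",
    "POST /tag-bindings/create": "verb-based create (TagBindingCreateApi on the kept side)",
    "POST /tag-bindings/remove": "verb-based delete (TagBindingDeleteApi on the kept side)",
}
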
@@ -595,25 +595,13 @@ class ChangeEmailSendEmailApi(Resource):
        account = None
        user_email = None
        email_for_sending = args.email.lower()
        # Default to the initial phase; any legacy/unexpected client input is
        # coerced back to `old_email` so we never trust the caller to declare
        # later phases without a verified predecessor token.
        send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
        if args.phase is not None and args.phase == AccountService.CHANGE_EMAIL_PHASE_NEW:
            send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
        if args.phase is not None and args.phase == "new_email":
            if args.token is None:
                raise InvalidTokenError()

            reset_data = AccountService.get_change_email_data(args.token)
            if reset_data is None:
                raise InvalidTokenError()

            # The token used to request a new-email code must come from the
            # old-email verification step. This prevents the bypass described
            # in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
            token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
            if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
                raise InvalidTokenError()
            user_email = reset_data.get("email", "")

            if user_email.lower() != current_user.email.lower():
@@ -632,7 +620,7 @@ class ChangeEmailSendEmailApi(Resource):
            email=email_for_sending,
            old_email=user_email,
            language=language,
            phase=send_phase,
            phase=args.phase,
        )
        return {"result": "success", "data": token}

@@ -667,31 +655,12 @@ class ChangeEmailCheckApi(Resource):
            AccountService.add_change_email_error_rate_limit(user_email)
            raise EmailCodeError()

        # Only advance tokens that were minted by the matching send-code step;
        # refuse tokens that have already progressed or lack a phase marker so
        # the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
        # is strictly enforced.
        phase_transitions = {
            AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
            AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
        }
        token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
        if not isinstance(token_phase, str):
            raise InvalidTokenError()
        refreshed_phase = phase_transitions.get(token_phase)
        if refreshed_phase is None:
            raise InvalidTokenError()

        # Verified, revoke the first token
        AccountService.revoke_change_email_token(args.token)

        # Refresh token data by generating a new token that carries the
        # upgraded phase so later steps can check it.
        # Refresh token data by generating a new token
        _, new_token = AccountService.generate_change_email_token(
            user_email,
            code=args.code,
            old_email=token_data.get("old_email"),
            additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
            user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
        )

        AccountService.reset_change_email_error_rate_limit(user_email)
@@ -721,29 +690,13 @@ class ChangeEmailResetApi(Resource):
        if not reset_data:
            raise InvalidTokenError()

        # Only tokens that completed both verification phases may be used to
        # change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
        # the initial send-code step could be replayed directly here.
        token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
        if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
            raise InvalidTokenError()

        # Bind the new email to the token that was mailed and verified, so a
        # verified token cannot be reused with a different `new_email` value.
        token_email = reset_data.get("email")
        normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
        if normalized_token_email != normalized_new_email:
            raise InvalidTokenError()
        AccountService.revoke_change_email_token(args.token)

        old_email = reset_data.get("old_email", "")
        current_user, _ = current_account_with_tenant()
        if current_user.email.lower() != old_email.lower():
            raise AccountNotFound()

        # Revoke only after all checks pass so failed attempts don't burn a
        # legitimately verified token.
        AccountService.revoke_change_email_token(args.token)

        updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)

        AccountService.send_change_email_completed_notify_email(

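The removed comments describe a strictly ordered phase chain. A minimal standalone sketch of that chain, assuming illustrative string values for the AccountService constants named in the diff (the real values are not shown here):

# Assumed values; the diff only names the constants.
CHANGE_EMAIL_PHASE_OLD = "old_email"
CHANGE_EMAIL_PHASE_OLD_VERIFIED = "old_email_verified"
CHANGE_EMAIL_PHASE_NEW = "new_email"
CHANGE_EMAIL_PHASE_NEW_VERIFIED = "new_email_verified"

# A token may only step forward along the chain; anything else is rejected.
PHASE_TRANSITIONS = {
    CHANGE_EMAIL_PHASE_OLD: CHANGE_EMAIL_PHASE_OLD_VERIFIED,
    CHANGE_EMAIL_PHASE_NEW: CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}

def advance_phase(token_phase: str) -> str:
    """Sketch of the removed ChangeEmailCheckApi upgrade step."""
    next_phase = PHASE_TRANSITIONS.get(token_phase)
    if next_phase is None:
        raise ValueError("invalid or already-progressed token phase")
    return next_phase
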
@@ -1,11 +1,3 @@
"""Console workspace endpoint controllers.

This module exposes workspace-scoped plugin endpoint management APIs. The
canonical write routes follow resource-oriented paths, while the historical
verb-based aliases stay available as deprecated resources so OpenAPI metadata
marks only the legacy paths as deprecated.
"""

from typing import Any

from flask import request
@@ -33,12 +25,7 @@ class EndpointIdPayload(BaseModel):
    endpoint_id: str


class EndpointUpdatePayload(BaseModel):
    settings: dict[str, Any]
    name: str = Field(min_length=1)


class LegacyEndpointUpdatePayload(EndpointIdPayload):
class EndpointUpdatePayload(EndpointIdPayload):
    settings: dict[str, Any]
    name: str = Field(min_length=1)

@@ -89,7 +76,6 @@ register_schema_models(
    EndpointCreatePayload,
    EndpointIdPayload,
    EndpointUpdatePayload,
    LegacyEndpointUpdatePayload,
    EndpointListQuery,
    EndpointListForPluginQuery,
    EndpointCreateResponse,
@@ -102,60 +88,8 @@ register_schema_models(
)


def _create_endpoint() -> dict[str, bool]:
    """Create a plugin endpoint for the current workspace."""
    user, tenant_id = current_account_with_tenant()

    args = EndpointCreatePayload.model_validate(console_ns.payload)

    try:
        return {
            "success": EndpointService.create_endpoint(
                tenant_id=tenant_id,
                user_id=user.id,
                plugin_unique_identifier=args.plugin_unique_identifier,
                name=args.name,
                settings=args.settings,
            )
        }
    except PluginPermissionDeniedError as e:
        raise ValueError(e.description) from e


def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
    """Update a plugin endpoint identified by the canonical path parameter."""
    user, tenant_id = current_account_with_tenant()

    args = EndpointUpdatePayload.model_validate(console_ns.payload)

    return {
        "success": EndpointService.update_endpoint(
            tenant_id=tenant_id,
            user_id=user.id,
            endpoint_id=endpoint_id,
            name=args.name,
            settings=args.settings,
        )
    }


def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
    """Delete a plugin endpoint identified by the canonical path parameter."""
    user, tenant_id = current_account_with_tenant()

    return {
        "success": EndpointService.delete_endpoint(
            tenant_id=tenant_id,
            user_id=user.id,
            endpoint_id=endpoint_id,
        )
    }


@console_ns.route("/workspaces/current/endpoints")
class EndpointCollectionApi(Resource):
    """Canonical collection resource for endpoint creation."""

@console_ns.route("/workspaces/current/endpoints/create")
class EndpointCreateApi(Resource):
    @console_ns.doc("create_endpoint")
    @console_ns.doc(description="Create a new plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
@@ -170,33 +104,22 @@ class EndpointCollectionApi(Resource):
    @is_admin_or_owner_required
    @account_initialization_required
    def post(self):
        return _create_endpoint()
        user, tenant_id = current_account_with_tenant()

        args = EndpointCreatePayload.model_validate(console_ns.payload)

@console_ns.route("/workspaces/current/endpoints/create")
class DeprecatedEndpointCreateApi(Resource):
    """Deprecated verb-based alias for endpoint creation."""

    @console_ns.doc("create_endpoint_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(
        description=(
            "Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
        )
    )
    @console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
    @console_ns.response(
        200,
        "Endpoint created successfully",
        console_ns.models[EndpointCreateResponse.__name__],
    )
    @console_ns.response(403, "Admin privileges required")
    @setup_required
    @login_required
    @is_admin_or_owner_required
    @account_initialization_required
    def post(self):
        return _create_endpoint()
        try:
            return {
                "success": EndpointService.create_endpoint(
                    tenant_id=tenant_id,
                    user_id=user.id,
                    plugin_unique_identifier=args.plugin_unique_identifier,
                    name=args.name,
                    settings=args.settings,
                )
            }
        except PluginPermissionDeniedError as e:
            raise ValueError(e.description) from e


@console_ns.route("/workspaces/current/endpoints/list")
@@ -267,56 +190,10 @@ class EndpointListForSinglePluginApi(Resource):
    )


@console_ns.route("/workspaces/current/endpoints/<string:id>")
class EndpointItemApi(Resource):
    """Canonical item resource for endpoint updates and deletion."""

@console_ns.route("/workspaces/current/endpoints/delete")
class EndpointDeleteApi(Resource):
    @console_ns.doc("delete_endpoint")
    @console_ns.doc(description="Delete a plugin endpoint")
    @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
    @console_ns.response(
        200,
        "Endpoint deleted successfully",
        console_ns.models[EndpointDeleteResponse.__name__],
    )
    @console_ns.response(403, "Admin privileges required")
    @setup_required
    @login_required
    @is_admin_or_owner_required
    @account_initialization_required
    def delete(self, id: str):
        return _delete_endpoint(endpoint_id=id)

    @console_ns.doc("update_endpoint")
    @console_ns.doc(description="Update a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
    @console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
    @console_ns.response(
        200,
        "Endpoint updated successfully",
        console_ns.models[EndpointUpdateResponse.__name__],
    )
    @console_ns.response(403, "Admin privileges required")
    @setup_required
    @login_required
    @is_admin_or_owner_required
    @account_initialization_required
    def patch(self, id: str):
        return _update_endpoint(endpoint_id=id)


@console_ns.route("/workspaces/current/endpoints/delete")
class DeprecatedEndpointDeleteApi(Resource):
    """Deprecated verb-based alias for endpoint deletion."""

    @console_ns.doc("delete_endpoint_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(
        description=(
            "Deprecated legacy alias for deleting a plugin endpoint. "
            "Use DELETE /workspaces/current/endpoints/{id} instead."
        )
    )
    @console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
    @console_ns.response(
        200,
@@ -329,23 +206,22 @@ class DeprecatedEndpointDeleteApi(Resource):
    @is_admin_or_owner_required
    @account_initialization_required
    def post(self):
        user, tenant_id = current_account_with_tenant()

        args = EndpointIdPayload.model_validate(console_ns.payload)
        return _delete_endpoint(endpoint_id=args.endpoint_id)

        return {
            "success": EndpointService.delete_endpoint(
                tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
            )
        }


@console_ns.route("/workspaces/current/endpoints/update")
class DeprecatedEndpointUpdateApi(Resource):
    """Deprecated verb-based alias for endpoint updates."""

    @console_ns.doc("update_endpoint_deprecated")
    @console_ns.doc(deprecated=True)
    @console_ns.doc(
        description=(
            "Deprecated legacy alias for updating a plugin endpoint. "
            "Use PATCH /workspaces/current/endpoints/{id} instead."
        )
    )
    @console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
class EndpointUpdateApi(Resource):
    @console_ns.doc("update_endpoint")
    @console_ns.doc(description="Update a plugin endpoint")
    @console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
    @console_ns.response(
        200,
        "Endpoint updated successfully",
@@ -357,8 +233,19 @@ class DeprecatedEndpointUpdateApi(Resource):
    @is_admin_or_owner_required
    @account_initialization_required
    def post(self):
        args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
        return _update_endpoint(endpoint_id=args.endpoint_id)
        user, tenant_id = current_account_with_tenant()

        args = EndpointUpdatePayload.model_validate(console_ns.payload)

        return {
            "success": EndpointService.update_endpoint(
                tenant_id=tenant_id,
                user_id=user.id,
                endpoint_id=args.endpoint_id,
                name=args.name,
                settings=args.settings,
            )
        }


@console_ns.route("/workspaces/current/endpoints/enable")

api/controllers/openapi/__init__.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from flask import Blueprint
from flask_restx import Namespace

from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi

bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
attach_anti_framing(bp)

api = ExternalApi(
    bp,
    version="1.0",
    title="OpenAPI",
    description="User-scoped programmatic API (bearer auth)",
)

openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")

from . import (
    account,
    app_run,
    apps,
    apps_permitted,
    index,
    oauth_device,
    oauth_device_sso,
    workspaces,
)

__all__ = [
    "account",
    "app_run",
    "apps",
    "apps_permitted",
    "index",
    "oauth_device",
    "oauth_device_sso",
    "workspaces",
]

api.add_namespace(openapi_ns)
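For orientation, a hedged sketch of how a Flask app could mount this blueprint; Dify's real app factory performs the registration elsewhere, so this wiring is illustrative only:

from flask import Flask

from controllers.openapi import bp

# Hypothetical wiring sketch: once registered, all namespace routes
# resolve under the blueprint's url_prefix, /openapi/v1/*.
app = Flask(__name__)
app.register_blueprint(bp)
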
api/controllers/openapi/_audit.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""Audit emission for openapi app-run endpoints.

Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
matches the existing oauth_device convention. The EE OTel exporter consults
its own allowlist to decide whether to ship the line.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

EVENT_APP_RUN_OPENAPI = "app.run.openapi"


def emit_app_run(*, app_id: str, tenant_id: str, caller_kind: str, mode: str) -> None:
    logger.info(
        "audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s",
        EVENT_APP_RUN_OPENAPI,
        app_id,
        tenant_id,
        caller_kind,
        mode,
        extra={
            "audit": True,
            "event": EVENT_APP_RUN_OPENAPI,
            "app_id": app_id,
            "tenant_id": tenant_id,
            "caller_kind": caller_kind,
            "mode": mode,
        },
    )
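Since `extra` keys land directly on the LogRecord, a handler-side consumer can select these lines with a plain logging.Filter. A minimal sketch (illustrative only; the EE exporter's actual allowlist logic is not shown in this diff):

import logging

class AuditOnlyFilter(logging.Filter):
    """Pass only records emitted with extra={"audit": True, ...}."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Attributes passed via `extra` become attributes on the record.
        return getattr(record, "audit", False) is True

handler = logging.StreamHandler()
handler.addFilter(AuditOnlyFilter())
logging.getLogger("controllers.openapi._audit").addHandler(handler)
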
api/controllers/openapi/_input_schema.py (new file, 143 lines)
@@ -0,0 +1,143 @@
"""Server-side JSON Schema derivation from Dify `user_input_form`."""

from __future__ import annotations

from typing import Any, cast

from controllers.service_api.app.error import AppUnavailableError
from models import App
from models.model import AppMode

JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"

EMPTY_INPUT_SCHEMA: dict[str, Any] = {
    "$schema": JSON_SCHEMA_DRAFT,
    "type": "object",
    "properties": {},
    "required": [],
}

_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})


def _file_object_shape() -> dict[str, Any]:
    """Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
    return {
        "type": "object",
        "properties": {
            "type": {"type": "string"},
            "transfer_method": {"type": "string"},
            "url": {"type": "string"},
            "upload_file_id": {"type": "string"},
        },
        "additionalProperties": True,
    }


def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
    label = row.get("label") or row.get("variable", "")
    base: dict[str, Any] = {"title": label} if label else {}

    if row_type in ("text-input", "paragraph"):
        out = {"type": "string"} | base
        max_length = row.get("max_length")
        if isinstance(max_length, int) and max_length > 0:
            out["maxLength"] = max_length
        return out

    if row_type == "select":
        return {"type": "string"} | base | {"enum": list(row.get("options") or [])}

    if row_type == "number":
        return {"type": "number"} | base

    if row_type == "file":
        return _file_object_shape() | base

    if row_type == "file-list":
        return {
            "type": "array",
            "items": _file_object_shape(),
        } | base

    return None


def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
    """Translate a user_input_form row list into (properties, required-list).

    Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
    Unknown variable types are skipped (forward-compat).
    """
    properties: dict[str, Any] = {}
    required: list[str] = []
    for row in form:
        if not isinstance(row, dict) or len(row) != 1:
            continue
        ((row_type, row_body),) = row.items()
        if not isinstance(row_body, dict):
            continue
        variable = row_body.get("variable")
        if not variable:
            continue
        schema = _row_to_schema(row_type, row_body)
        if schema is None:
            continue
        properties[variable] = schema
        if row_body.get("required"):
            required.append(variable)
    return properties, required


def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Resolve `(features_dict, user_input_form)` for parameters / schema derivation.

    Raises `AppUnavailableError` on misconfigured apps.
    """
    if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        workflow = app.workflow
        if workflow is None:
            raise AppUnavailableError()
        return (
            workflow.features_dict,
            cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
        )

    app_model_config = app.app_model_config
    if app_model_config is None:
        raise AppUnavailableError()
    features_dict = cast(dict[str, Any], app_model_config.to_dict())
    return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))


def build_input_schema(app: App) -> dict[str, Any]:
    """Derive Draft 2020-12 JSON Schema from `user_input_form` + app mode.

    chat / agent-chat / advanced-chat: top-level `query` (required, minLength=1) + `inputs` object.
    completion / workflow: `inputs` object only.
    Raises `AppUnavailableError` on misconfigured apps.
    """
    _, user_input_form = resolve_app_config(app)
    inputs_props, inputs_required = _form_to_jsonschema(user_input_form)

    properties: dict[str, Any] = {}
    required: list[str] = []

    if app.mode in _CHAT_FAMILY:
        properties["query"] = {"type": "string", "minLength": 1}
        required.append("query")

    properties["inputs"] = {
        "type": "object",
        "properties": inputs_props,
        "required": inputs_required,
        "additionalProperties": False,
    }
    required.append("inputs")

    return {
        "$schema": JSON_SCHEMA_DRAFT,
        "type": "object",
        "properties": properties,
        "required": required,
    }
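To make the row translation concrete, a small worked example of what `_form_to_jsonschema` yields for a typical form (row values invented):

form = [
    {"text-input": {"variable": "topic", "label": "Topic", "required": True, "max_length": 64}},
    {"select": {"variable": "tone", "label": "Tone", "options": ["formal", "casual"]}},
]

properties, required = _form_to_jsonschema(form)
# properties == {
#     "topic": {"type": "string", "title": "Topic", "maxLength": 64},
#     "tone": {"type": "string", "title": "Tone", "enum": ["formal", "casual"]},
# }
# required == ["topic"]
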
api/controllers/openapi/_models.py (new file, 112 lines)
@@ -0,0 +1,112 @@
"""Shared response substructures for openapi endpoints."""

from __future__ import annotations

from typing import Any, Literal

from pydantic import BaseModel, Field

# Server-side cap on `limit` query param for any /openapi/v1/* list endpoint.
# Sibling endpoints (`/apps`, `/account/sessions`, future routes) all clamp to
# this; do not introduce per-endpoint caps without raising the constant.
MAX_PAGE_LIMIT = 200


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class MessageMetadata(BaseModel):
    usage: UsageInfo | None = None
    retriever_resources: list[dict[str, Any]] = []


class PaginationEnvelope[T](BaseModel):
    """Canonical pagination envelope for `/openapi/v1/*` list endpoints."""

    page: int
    limit: int
    total: int
    has_more: bool
    data: list[T]

    @classmethod
    def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
        return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)


class AppListRow(BaseModel):
    id: str
    name: str
    description: str | None = None
    mode: str
    tags: list[dict[str, str]] = []
    updated_at: str | None = None
    created_by_name: str | None = None
    workspace_id: str | None = None
    workspace_name: str | None = None


class AppInfoResponse(BaseModel):
    id: str
    name: str
    description: str | None = None
    mode: str
    author: str | None = None
    tags: list[dict[str, str]] = []


class AppDescribeInfo(AppInfoResponse):
    updated_at: str | None = None
    service_api_enabled: bool


class AppDescribeResponse(BaseModel):
    info: AppDescribeInfo | None = None
    parameters: dict[str, Any] | None = None
    input_schema: dict[str, Any] | None = None


class ChatMessageResponse(BaseModel):
    event: str
    task_id: str
    id: str
    message_id: str
    conversation_id: str
    mode: str
    answer: str
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    created_at: int


class CompletionMessageResponse(BaseModel):
    event: str
    task_id: str
    id: str
    message_id: str
    mode: str
    answer: str
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    created_at: int


class WorkflowRunData(BaseModel):
    id: str
    workflow_id: str
    status: str
    outputs: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None
    total_steps: int | None = None
    created_at: int | None = None
    finished_at: int | None = None


class WorkflowRunResponse(BaseModel):
    workflow_run_id: str
    task_id: str
    mode: Literal["workflow"] = "workflow"
    data: WorkflowRunData
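A quick check of the `has_more` arithmetic in `PaginationEnvelope.build` (values invented):

env = PaginationEnvelope[int].build(page=2, limit=50, total=120, items=list(range(50)))
assert env.has_more is True   # 2 * 50 = 100 < 120, so another page exists

last = PaginationEnvelope[int].build(page=3, limit=50, total=120, items=list(range(20)))
assert last.has_more is False  # 3 * 50 = 150 >= 120, nothing left
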
api/controllers/openapi/account.py (new file, 236 lines)
@@ -0,0 +1,236 @@
"""User-scoped account endpoints. /account is the bearer-authed
identity read; /account/sessions and /account/sessions/<id> manage
the user's active OAuth tokens.
"""

from __future__ import annotations

from datetime import UTC, datetime

from flask import g, request
from flask_restx import Resource
from sqlalchemy import and_, select, update
from werkzeug.exceptions import BadRequest, NotFound

from controllers.openapi import openapi_ns
from controllers.openapi._models import MAX_PAGE_LIMIT, PaginationEnvelope
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
    ACCEPT_USER_ANY,
    TOKEN_CACHE_KEY_FMT,
    AuthContext,
    SubjectType,
    validate_bearer,
)
from libs.rate_limit import (
    LIMIT_ME_PER_ACCOUNT,
    LIMIT_ME_PER_EMAIL,
    enforce,
)
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin


@openapi_ns.route("/account")
class AccountApi(Resource):
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def get(self):
        ctx = g.auth_ctx

        if ctx.subject_type == SubjectType.EXTERNAL_SSO:
            enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
        else:
            enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")

        if ctx.subject_type == SubjectType.EXTERNAL_SSO:
            return {
                "subject_type": ctx.subject_type,
                "subject_email": ctx.subject_email,
                "subject_issuer": ctx.subject_issuer,
                "account": None,
                "workspaces": [],
                "default_workspace_id": None,
            }

        account = (
            db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
        )
        memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
        default_ws_id = _pick_default_workspace(memberships)

        return {
            "subject_type": ctx.subject_type,
            "subject_email": ctx.subject_email or (account.email if account else None),
            "account": _account_payload(account) if account else None,
            "workspaces": [_workspace_payload(m) for m in memberships],
            "default_workspace_id": default_ws_id,
        }


@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def delete(self):
        ctx = g.auth_ctx
        _require_oauth_subject(ctx)
        _revoke_token_by_id(str(ctx.token_id))
        return {"status": "revoked"}, 200


@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def get(self):
        ctx = g.auth_ctx
        now = datetime.now(UTC)
        page = int(request.args.get("page", "1"))
        limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)

        all_rows = db.session.execute(
            select(
                OAuthAccessToken.id,
                OAuthAccessToken.prefix,
                OAuthAccessToken.client_id,
                OAuthAccessToken.device_label,
                OAuthAccessToken.created_at,
                OAuthAccessToken.last_used_at,
                OAuthAccessToken.expires_at,
            )
            .where(
                and_(
                    *_subject_match(ctx),
                    OAuthAccessToken.revoked_at.is_(None),
                    OAuthAccessToken.token_hash.is_not(None),
                    OAuthAccessToken.expires_at > now,
                )
            )
            .order_by(OAuthAccessToken.created_at.desc())
        ).all()

        total = len(all_rows)
        sliced = all_rows[(page - 1) * limit : page * limit]

        items = [
            {
                "id": str(r.id),
                "prefix": r.prefix,
                "client_id": r.client_id,
                "device_label": r.device_label,
                "created_at": _iso(r.created_at),
                "last_used_at": _iso(r.last_used_at),
                "expires_at": _iso(r.expires_at),
            }
            for r in sliced
        ]

        return (
            PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
            200,
        )


@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
    @validate_bearer(accept=ACCEPT_USER_ANY)
    def delete(self, session_id: str):
        ctx = g.auth_ctx
        _require_oauth_subject(ctx)

        # Subject-match guard. 404 (not 403) on cross-subject so the
        # endpoint doesn't leak token IDs that belong to other subjects.
        owns = db.session.execute(
            select(OAuthAccessToken.id).where(
                and_(
                    OAuthAccessToken.id == session_id,
                    *_subject_match(ctx),
                )
            )
        ).first()
        if owns is None:
            raise NotFound("session not found")

        _revoke_token_by_id(session_id)
        return {"status": "revoked"}, 200


def _subject_match(ctx: AuthContext) -> tuple:
    """Where-clauses that scope a query to the bearer's subject. Works
    for both account (account_id) and external_sso (email + issuer).
    """
    if ctx.subject_type == SubjectType.ACCOUNT:
        return (OAuthAccessToken.account_id == str(ctx.account_id),)
    return (
        OAuthAccessToken.subject_email == ctx.subject_email,
        OAuthAccessToken.subject_issuer == ctx.subject_issuer,
        OAuthAccessToken.account_id.is_(None),
    )


def _require_oauth_subject(ctx: AuthContext) -> None:
    if not ctx.source.startswith("oauth"):
        raise BadRequest(
            "this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
        )


def _revoke_token_by_id(token_id: str) -> None:
    # Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
    # makes double-revoke idempotent.
    row = (
        db.session.query(OAuthAccessToken.token_hash)
        .filter(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .one_or_none()
    )
    pre_revoke_hash = row[0] if row else None

    stmt = (
        update(OAuthAccessToken)
        .where(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(UTC), token_hash=None)
    )
    db.session.execute(stmt)
    db.session.commit()

    if pre_revoke_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))


def _iso(dt: datetime | None) -> str | None:
    if dt is None:
        return None
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=UTC)
    return dt.isoformat().replace("+00:00", "Z")


def _load_memberships(account_id):
    return (
        db.session.query(TenantAccountJoin, Tenant)
        .join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
        .filter(TenantAccountJoin.account_id == account_id)
        .all()
    )


def _pick_default_workspace(memberships) -> str | None:
    if not memberships:
        return None
    for join, tenant in memberships:
        if getattr(join, "current", False):
            return str(tenant.id)
    return str(memberships[0][1].id)


def _workspace_payload(row) -> dict:
    join, tenant = row
    return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}


def _account_payload(account) -> dict:
    return {"id": str(account.id), "email": account.email, "name": account.name}
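A small illustration of the `_iso` normalisation above: naive datetimes are treated as UTC, and the `+00:00` offset is rewritten as `Z` (sample values invented):

from datetime import UTC, datetime

assert _iso(datetime(2025, 1, 2, 3, 4, 5)) == "2025-01-02T03:04:05Z"              # naive -> assumed UTC
assert _iso(datetime(2025, 1, 2, 3, 4, 5, tzinfo=UTC)) == "2025-01-02T03:04:05Z"  # aware UTC
assert _iso(None) is None
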
api/controllers/openapi/app_run.py (new file, 198 lines)
@@ -0,0 +1,198 @@
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""

from __future__ import annotations

import logging
from collections.abc import Callable, Iterator, Mapping
from contextlib import contextmanager
from typing import Any, Literal
from uuid import UUID

from flask import request
from flask_restx import Resource
from pydantic import BaseModel, ValidationError, field_validator
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity

import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import (
    ChatMessageResponse,
    CompletionMessageResponse,
    WorkflowRunResponse,
)
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.service_api.app.error import (
    AppUnavailableError,
    CompletionRequestError,
    ConversationCompletedError,
    ProviderModelCurrentlyNotSupportError,
    ProviderNotInitializeError,
    ProviderQuotaExceededError,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.oauth_bearer import Scope
from models.model import App, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import (
    IsDraftWorkflowError,
    WorkflowIdFormatError,
    WorkflowNotFoundError,
)
from services.errors.llm import InvokeRateLimitError

logger = logging.getLogger(__name__)


class AppRunRequest(BaseModel):
    inputs: dict[str, Any]
    query: str | None = None
    files: list[dict[str, Any]] | None = None
    response_mode: Literal["blocking", "streaming"] | None = None
    conversation_id: UUIDStrOrEmpty | None = None
    auto_generate_name: bool = True
    workflow_id: str | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def _normalize_conv(cls, value: str | UUID | None) -> str | None:
        if isinstance(value, str):
            value = value.strip()
        if not value:
            return None
        try:
            return helper.uuid_value(value)
        except ValueError as exc:
            raise ValueError("conversation_id must be a valid UUID") from exc


@contextmanager
def _translate_service_errors() -> Iterator[None]:
    try:
        yield
    except WorkflowNotFoundError as ex:
        raise NotFound(str(ex))
    except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
        raise BadRequest(str(ex))
    except services.errors.conversation.ConversationNotExistsError:
        raise NotFound("Conversation Not Exists.")
    except services.errors.conversation.ConversationCompletedError:
        raise ConversationCompletedError()
    except services.errors.app_model_config.AppModelConfigBrokenError:
        logger.exception("App model config broken.")
        raise AppUnavailableError()
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    except QuotaExceededError:
        raise ProviderQuotaExceededError()
    except ModelCurrentlyNotSupportError:
        raise ProviderModelCurrentlyNotSupportError()
    except InvokeRateLimitError as ex:
        raise InvokeRateLimitHttpError(ex.description)
    except InvokeError as e:
        raise CompletionRequestError(e.description)


def _unpack_blocking(response: Any) -> Mapping[str, Any]:
    if isinstance(response, tuple):
        response = response[0]
    if not isinstance(response, Mapping):
        raise InternalServerError("blocking generate returned non-mapping response")
    return response


def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
    return AppGenerateService.generate(
        app_model=app,
        user=caller,
        args=args,
        invoke_from=InvokeFrom.OPENAPI,
        streaming=streaming,
    )


def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    if not payload.query or not payload.query.strip():
        raise UnprocessableEntity("query_required_for_chat")
    args = payload.model_dump(exclude_none=True)
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, ChatMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    args = payload.model_dump(exclude_none=True)
    args["auto_generate_name"] = False
    args.setdefault("query", "")
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, CompletionMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    if payload.query is not None:
        raise UnprocessableEntity("query_not_supported_for_workflow")
    args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, WorkflowRunResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
    AppMode.CHAT: _run_chat,
    AppMode.AGENT_CHAT: _run_chat,
    AppMode.ADVANCED_CHAT: _run_chat,
    AppMode.COMPLETION: _run_completion,
    AppMode.WORKFLOW: _run_workflow,
}


@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
    @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
    def post(self, app_id: str, app_model: App, caller, caller_kind: str):
        body = request.get_json(silent=True) or {}
        body.pop("user", None)
        try:
            payload = AppRunRequest.model_validate(body)
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())

        handler = _DISPATCH.get(app_model.mode)
        if handler is None:
            raise UnprocessableEntity("mode_not_runnable")

        streaming = payload.response_mode == "streaming"
        try:
            stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
        except HTTPException:
            raise
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()

        emit_app_run(
            app_id=app_model.id,
            tenant_id=app_model.tenant_id,
            caller_kind=caller_kind,
            mode=str(app_model.mode),
        )

        if streaming:
            return helper.compact_generate_response(stream_obj)
        return blocking_body, 200
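For orientation, hedged example request bodies that satisfy the per-mode dispatch rules above (field values invented; `AppRunRequest` defines the accepted shape):

# Chat family (chat / agent-chat / advanced-chat): `query` is mandatory and must be non-blank.
chat_body = {
    "inputs": {"topic": "release notes"},
    "query": "Summarise this week's changes",
    "response_mode": "blocking",
}

# Workflow: `query` must be absent, otherwise the handler raises 422 query_not_supported_for_workflow.
workflow_body = {
    "inputs": {"document_url": "https://example.com/spec.pdf"},
    "response_mode": "streaming",
}
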
315
api/controllers/openapi/apps.py
Normal file
315
api/controllers/openapi/apps.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → sets `g.auth_ctx` before `require_scope` reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.app_service import AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
|
||||
"""`?fields=` allow-list for GET /apps/<id>/describe.
|
||||
|
||||
Empty / omitted → all blocks. Unknown member → ValidationError → 422.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
fields: set[str] | None = None
|
||||
workspace_id: str | None = None
|
||||
|
||||
@field_validator("workspace_id", mode="before")
|
||||
@classmethod
|
||||
def _validate_workspace_id(cls, v: object) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("workspace_id must be a string")
|
||||
try:
|
||||
_uuid.UUID(v)
|
||||
except ValueError:
|
||||
raise ValueError("workspace_id must be a valid UUID")
|
||||
return v
|
||||
|
||||
@field_validator("fields", mode="before")
|
||||
@classmethod
|
||||
def _parse_fields(cls, v: object) -> set[str] | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
raise ValueError("fields must be a comma-separated string")
|
||||
members = {m.strip() for m in v.split(",") if m.strip()}
|
||||
unknown = members - _ALLOWED_DESCRIBE_FIELDS
|
||||
if unknown:
|
||||
raise ValueError(f"unknown field(s): {sorted(unknown)}")
|
||||
return members
|
||||
|
||||
|
||||
_EMPTY_PARAMETERS: dict[str, Any] = {
|
||||
"opening_statement": None,
|
||||
"suggested_questions": [],
|
||||
"user_input_form": [],
|
||||
"file_upload": None,
|
||||
"system_parameters": {},
|
||||
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
|
||||
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
|
||||
|
||||
method_decorators = _APPS_READ_DECORATORS
|
||||
|
||||
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
|
||||
ctx = g.auth_ctx
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
|
||||
raise NotFound("app not found")
|
||||
|
||||
try:
|
||||
parsed_uuid = _uuid.UUID(app_id)
|
||||
is_uuid = True
|
||||
except ValueError:
|
||||
parsed_uuid = None
|
||||
is_uuid = False
|
||||
|
||||
if is_uuid:
|
||||
app = db.session.get(App, str(parsed_uuid)) # normalised dashed form
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
else:
|
||||
if not workspace_id:
|
||||
raise UnprocessableEntity("workspace_id is required for name-based lookup")
|
||||
matches = list(
|
||||
db.session.execute(
|
||||
sa.select(App).where(
|
||||
App.name == app_id,
|
||||
App.tenant_id == workspace_id,
|
||||
App.status == "normal",
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
if len(matches) == 0:
|
||||
raise NotFound("app not found")
|
||||
if len(matches) > 1:
|
||||
lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
|
||||
lines.append(f" {'ID':<36} {'MODE':<12} NAME\n")
|
||||
for m in matches:
|
||||
lines.append(f" {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
|
||||
raise Conflict("".join(lines))
|
||||
app = matches[0]
|
||||
|
||||
require_workspace_member(ctx, str(app.tenant_id))
|
||||
return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
|
||||
"""Mirrors service_api/app/app.py::AppParameterApi response body."""
|
||||
features_dict, user_input_form = resolve_app_config(app)
|
||||
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
|
||||
return Parameters.model_validate(parameters).model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
|
||||
class AppDescribeApi(AppReadResource):
|
||||
def get(self, app_id: str):
|
||||
try:
|
||||
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
app, _ = self._load(app_id, workspace_id=query.workspace_id)
|
||||
|
||||
requested = query.fields
|
||||
want_info = requested is None or "info" in requested
|
||||
want_params = requested is None or "parameters" in requested
|
||||
want_schema = requested is None or "input_schema" in requested
|
||||
|
||||
info = (
|
||||
AppDescribeInfo(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
mode=app.mode,
|
||||
description=app.description,
|
||||
tags=[{"name": t.name} for t in app.tags],
|
||||
author=app.author_name,
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
service_api_enabled=bool(app.enable_api),
|
||||
)
|
||||
if want_info
|
||||
else None
|
||||
)
|
||||
|
||||
parameters: dict[str, Any] | None = None
|
||||
input_schema: dict[str, Any] | None = None
|
||||
if want_params:
|
||||
try:
|
||||
parameters = parameters_payload(app)
|
||||
except AppUnavailableError:
|
||||
parameters = dict(_EMPTY_PARAMETERS)
|
||||
if want_schema:
|
||||
try:
|
||||
input_schema = build_input_schema(app)
|
||||
except AppUnavailableError:
|
||||
input_schema = dict(EMPTY_INPUT_SCHEMA)
|
||||
|
||||
return (
|
||||
AppDescribeResponse(
|
||||
info=info,
|
||||
parameters=parameters,
|
||||
input_schema=input_schema,
|
||||
).model_dump(mode="json", exclude_none=False),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
    """`mode` is a closed enum — unknown values 422 instead of silently-empty data."""

    workspace_id: str
    page: int = Field(1, ge=1)
    limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
    mode: AppMode | None = None
    name: str | None = Field(None, max_length=200)
    tag: str | None = Field(None, max_length=100)


@openapi_ns.route("/apps")
class AppListApi(Resource):
    method_decorators = _APPS_READ_DECORATORS

    def get(self):
        ctx = g.auth_ctx
        if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
            return PaginationEnvelope[AppListRow].build(page=1, limit=0, total=0, items=[]).model_dump(mode="json"), 200

        try:
            query = AppListQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())

        workspace_id = query.workspace_id
        require_workspace_member(ctx, workspace_id)

        empty = (
            PaginationEnvelope[AppListRow]
            .build(page=query.page, limit=query.limit, total=0, items=[])
            .model_dump(mode="json"),
            200,
        )

        if query.name:
            try:
                parsed_uuid = _uuid.UUID(query.name)
            except ValueError:
                parsed_uuid = None
        else:
            parsed_uuid = None

        if parsed_uuid is not None:
            app = db.session.get(App, str(parsed_uuid))
            if not app or app.status != "normal" or str(app.tenant_id) != workspace_id:
                return empty
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()
            item = AppListRow(
                id=str(app.id),
                name=app.name,
                description=app.description,
                mode=app.mode,
                tags=[{"name": t.name} for t in app.tags],
                updated_at=app.updated_at.isoformat() if app.updated_at else None,
                created_by_name=getattr(app, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            env = PaginationEnvelope[AppListRow].build(page=1, limit=1, total=1, items=[item])
            return env.model_dump(mode="json"), 200

        tag_ids: list[str] | None = None
        if query.tag:
            tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
            if not tags:
                return empty
            tag_ids = [tag.id for tag in tags]

        args: dict[str, Any] = {
            "page": query.page,
            "limit": query.limit,
            "mode": query.mode.value if query.mode else "",
            "name": query.name,
            "status": "normal",
        }
        if tag_ids:
            args["tag_ids"] = tag_ids

        pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, args)
        if pagination is None:
            return empty

        tenant_name: str | None = None
        if pagination.items:
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()

        items = [
            AppListRow(
                id=str(r.id),
                name=r.name,
                description=r.description,
                mode=r.mode,
                tags=[{"name": t.name} for t in r.tags],
                updated_at=r.updated_at.isoformat() if r.updated_at else None,
                created_by_name=getattr(r, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            for r in pagination.items
        ]
        env = PaginationEnvelope[AppListRow].build(
            page=query.page, limit=query.limit, total=int(pagination.total), items=items
        )
        return env.model_dump(mode="json"), 200
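When `name` parses as a UUID the handler short-circuits to a direct primary-key lookup, so a client can resolve an exact app without paging. A hedged sketch; the host, token, and the exact envelope keys (`total` in particular) are assumptions:

# Illustrative only. Passing a UUID in `name` returns at most one row.
import requests

resp = requests.get(
    "https://cloud.example.com/openapi/v1/apps",
    params={"workspace_id": "<workspace_id>", "name": "<app-uuid>"},
    headers={"Authorization": "Bearer dfoa_..."},
    timeout=10,
)
page = resp.json()
assert page["total"] in (0, 1)  # empty envelope or the single matched app
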
101 api/controllers/openapi/apps_permitted.py Normal file
@@ -0,0 +1,101 @@
"""GET /openapi/v1/apps/permitted — external-subject app discovery (EE only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_EXT_SSO,
|
||||
Scope,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
|
||||
class AppPermittedListQuery(BaseModel):
|
||||
"""Strict (`extra='forbid'`) — rejects `workspace_id`/`tag`/etc. that are valid on /apps but not here."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
page: int = Field(1, ge=1)
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/permitted")
|
||||
class AppPermittedListApi(Resource):
|
||||
method_decorators = [
|
||||
require_scope(Scope.APPS_READ_PERMITTED),
|
||||
validate_bearer(accept=ACCEPT_USER_EXT_SSO),
|
||||
enterprise_only,
|
||||
]
|
||||
|
||||
def get(self):
|
||||
try:
|
||||
query = AppPermittedListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
page_result = list_permitted_apps(
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
mode=query.mode.value if query.mode else None,
|
||||
name=query.name,
|
||||
)
|
||||
|
||||
if not page_result.app_ids:
|
||||
env = PaginationEnvelope[AppListRow].build(
|
||||
page=query.page, limit=query.limit, total=page_result.total, items=[]
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
|
||||
apps_by_id = {
|
||||
str(a.id): a
|
||||
for a in db.session.execute(sa.select(App).where(App.id.in_(page_result.app_ids))).scalars().all()
|
||||
}
|
||||
tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
|
||||
tenants_by_id = {
|
||||
str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
|
||||
}
|
||||
|
||||
items: list[AppListRow] = []
|
||||
for app_id in page_result.app_ids:
|
||||
app = apps_by_id.get(app_id)
|
||||
if not app or app.status != "normal":
|
||||
continue
|
||||
tenant = tenants_by_id.get(str(app.tenant_id))
|
||||
items.append(
|
||||
AppListRow(
|
||||
id=str(app.id),
|
||||
name=app.name,
|
||||
description=app.description,
|
||||
mode=app.mode,
|
||||
tags=[], # tenant-scoped; not surfaced cross-tenant
|
||||
updated_at=app.updated_at.isoformat() if app.updated_at else None,
|
||||
created_by_name=None, # cross-tenant author leak prevention
|
||||
workspace_id=str(app.tenant_id),
|
||||
workspace_name=tenant.name if tenant else None,
|
||||
)
|
||||
)
|
||||
|
||||
# total/has_more reflect the EE-side allow-list; len(items) may be < limit when local rows are dropped.
|
||||
env = PaginationEnvelope[AppListRow].build(
|
||||
page=query.page, limit=query.limit, total=page_result.total, items=items
|
||||
)
|
||||
return env.model_dump(mode="json"), 200
|
||||
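Because `total`/`has_more` come from the EE allow-list while stale local rows are dropped, a page can hold fewer than `limit` items without being the last page. A hedged client sketch that pages on the envelope rather than on `len(items)`; the `has_more` key is assumed from the comment above:

# Illustrative pagination loop, not project code. Never stop early on a short page.
import requests

def iter_permitted_apps(base_url: str, token: str, limit: int = 20):
    page = 1
    while True:
        body = requests.get(
            f"{base_url}/openapi/v1/apps/permitted",
            params={"page": page, "limit": limit},
            headers={"Authorization": f"Bearer {token}"},
            timeout=10,
        ).json()
        yield from body["items"]
        if not body.get("has_more"):  # envelope key assumed
            break
        page += 1
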
3 api/controllers/openapi/auth/__init__.py Normal file
@@ -0,0 +1,3 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE

__all__ = ["OAUTH_BEARER_PIPELINE"]
43 api/controllers/openapi/auth/composition.py Normal file
@@ -0,0 +1,43 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
return AclStrategy()
|
||||
return MembershipStrategy()
|
||||
|
||||
|
||||
OAUTH_BEARER_PIPELINE = Pipeline(
|
||||
BearerCheck(),
|
||||
ScopeCheck(),
|
||||
AppResolver(),
|
||||
WorkspaceMembershipCheck(),
|
||||
AppAuthzCheck(_resolve_app_authz_strategy),
|
||||
CallerMount(AccountMounter(), EndUserMounter()),
|
||||
)
|
||||
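A hedged sketch of how a `/run` endpoint would attach this pipeline; the route path, scope name `Scope.APPS_RUN`, and handler body are illustrative, not taken from this diff:

# Illustrative attachment only. The injected kwargs match Pipeline.guard's kwargs.update.
from flask_restx import Resource

from controllers.openapi import openapi_ns
from controllers.openapi.auth import OAUTH_BEARER_PIPELINE
from libs.oauth_bearer import Scope

@openapi_ns.route("/apps/<string:app_id>/run")
class AppRunApi(Resource):
    @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)  # scope name assumed
    def post(self, app_id: str, app_model=None, caller=None, caller_kind=None):
        # app_model, caller, and caller_kind are populated by the pipeline steps.
        return {"app": str(app_model.id), "caller_kind": caller_kind}, 200
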
46 api/controllers/openapi/auth/context.py Normal file
@@ -0,0 +1,46 @@
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
request: Request
|
||||
required_scope: Scope
|
||||
subject_type: SubjectType | None = None
|
||||
subject_email: str | None = None
|
||||
subject_issuer: str | None = None
|
||||
account_id: uuid.UUID | None = None
|
||||
scopes: frozenset[Scope] = field(default_factory=frozenset)
|
||||
token_id: uuid.UUID | None = None
|
||||
token_hash: str | None = None
|
||||
cached_verified_tenants: dict[str, bool] | None = None
|
||||
source: str | None = None
|
||||
expires_at: datetime | None = None
|
||||
app: App | None = None
|
||||
tenant: Tenant | None = None
|
||||
caller: object | None = None
|
||||
caller_kind: Literal["account", "end_user"] | None = None
|
||||
|
||||
|
||||
class Step(Protocol):
|
||||
"""One responsibility. Mutate ctx or raise to short-circuit."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None: ...
|
||||
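Since `Step` is a structural Protocol, any callable taking a `Context` qualifies without registration. A hedged sketch of an extra step; the policy and class are purely illustrative, and it assumes `expires_at` is timezone-aware:

# Illustrative Step, not part of this diff. Raising short-circuits Pipeline.run.
from datetime import datetime, timezone

from werkzeug.exceptions import Forbidden

class TokenFreshnessCheck:
    """Example policy: reject bearers expiring within the next 60 seconds."""

    def __call__(self, ctx: Context) -> None:
        if ctx.expires_at and (ctx.expires_at - datetime.now(timezone.utc)).total_seconds() < 60:
            raise Forbidden("token_about_to_expire")
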
41 api/controllers/openapi/auth/pipeline.py Normal file
@@ -0,0 +1,41 @@
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, *steps: Step) -> None:
|
||||
self._steps = steps
|
||||
|
||||
def run(self, ctx: Context) -> None:
|
||||
for step in self._steps:
|
||||
step(ctx)
|
||||
|
||||
def guard(self, *, scope: Scope):
|
||||
def decorator(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
ctx = Context(request=request, required_scope=scope)
|
||||
self.run(ctx)
|
||||
kwargs.update(
|
||||
app_model=ctx.app,
|
||||
caller=ctx.caller,
|
||||
caller_kind=ctx.caller_kind,
|
||||
)
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
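The flat run loop makes short-circuiting trivially testable. A hedged test sketch with stub steps; the test name and stubs are illustrative, and it leans on the fact that stub steps ignore `Context` fields:

# Illustrative test, not part of this diff.
import pytest
from werkzeug.exceptions import Unauthorized

def test_pipeline_short_circuits():
    calls = []

    def ok_step(ctx):
        calls.append("ok")

    def failing_step(ctx):
        raise Unauthorized("nope")

    pipeline = Pipeline(ok_step, failing_step, ok_step)
    with pytest.raises(Unauthorized):
        pipeline.run(object())  # stub steps never touch ctx fields
    assert calls == ["ok"]  # the step after the failure never ran
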
131 api/controllers/openapi/auth/steps.py Normal file
@@ -0,0 +1,131 @@
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
_extract_bearer, # type: ignore[attr-defined]
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
class BearerCheck:
|
||||
"""Resolve bearer → populate identity fields. Rate-limit is enforced
|
||||
inside `BearerAuthenticator.authenticate`, so no separate step here."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
token = _extract_bearer(ctx.request)
|
||||
if not token:
|
||||
raise Unauthorized("bearer required")
|
||||
|
||||
try:
|
||||
authn = get_authenticator().authenticate(token)
|
||||
except InvalidBearerError as e:
|
||||
raise Unauthorized(str(e))
|
||||
|
||||
ctx.subject_type = authn.subject_type
|
||||
ctx.subject_email = authn.subject_email
|
||||
ctx.subject_issuer = authn.subject_issuer
|
||||
ctx.account_id = authn.account_id
|
||||
ctx.scopes = frozenset(authn.scopes)
|
||||
ctx.source = authn.source
|
||||
ctx.token_id = authn.token_id
|
||||
ctx.expires_at = authn.expires_at
|
||||
ctx.token_hash = authn.token_hash
|
||||
ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
|
||||
|
||||
class ScopeCheck:
|
||||
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
|
||||
return
|
||||
raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class AppResolver:
|
||||
"""Read app_id from request.view_args, populate ctx.app + ctx.tenant.
|
||||
|
||||
Every endpoint using the OAuth bearer pipeline must declare
|
||||
``<string:app_id>`` in its route — that is the design lock-in (no body /
|
||||
header coupling).
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
app_id = (ctx.request.view_args or {}).get("app_id")
|
||||
if not app_id:
|
||||
raise BadRequest("app_id is required in path")
|
||||
app = db.session.get(App, app_id)
|
||||
if not app or app.status != "normal":
|
||||
raise NotFound("app not found")
|
||||
if not app.enable_api:
|
||||
raise Forbidden("service_api_disabled")
|
||||
tenant = db.session.get(Tenant, app.tenant_id)
|
||||
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("workspace unavailable")
|
||||
ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
|
||||
"""Layer 0 — workspace membership gate.
|
||||
|
||||
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
|
||||
(dfoa_) only — SSO subjects skip.
|
||||
"""
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
if ctx.subject_type != SubjectType.ACCOUNT:
|
||||
return
|
||||
if ctx.account_id is None or ctx.tenant is None:
|
||||
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
|
||||
if ctx.token_hash is None:
|
||||
raise Unauthorized("token_hash unset — BearerCheck did not run")
|
||||
|
||||
check_workspace_membership(
|
||||
account_id=ctx.account_id,
|
||||
tenant_id=ctx.tenant.id,
|
||||
token_hash=ctx.token_hash,
|
||||
cached_verdicts=ctx.cached_verified_tenants or {},
|
||||
)
|
||||
|
||||
|
||||
class AppAuthzCheck:
|
||||
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
|
||||
self._resolve = resolve_strategy
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if not self._resolve().authorize(ctx):
|
||||
raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
|
||||
def __init__(self, *mounters: CallerMounter) -> None:
|
||||
self._mounters = mounters
|
||||
|
||||
def __call__(self, ctx: Context) -> None:
|
||||
if ctx.subject_type is None:
|
||||
raise Unauthorized("subject_type unset — BearerCheck did not run")
|
||||
for m in self._mounters:
|
||||
if m.applies_to(ctx.subject_type):
|
||||
m.mount(ctx)
|
||||
return
|
||||
raise Unauthorized("no caller mounter for subject type")
|
||||
115 api/controllers/openapi/auth/strategies.py Normal file
@@ -0,0 +1,115 @@
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
|
||||
def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
|
||||
"""Per-app ACL via the workspace-auth inner API.
|
||||
|
||||
Used when webapp-auth is enabled (EE deployment). The inner-API
|
||||
allowlist is the source of truth.
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_email is None or ctx.app is None:
|
||||
return False
|
||||
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=ctx.subject_email,
|
||||
app_id=ctx.app.id,
|
||||
)
|
||||
|
||||
|
||||
class MembershipStrategy:
|
||||
"""Tenant-membership fallback.
|
||||
|
||||
Used when webapp-auth is disabled (CE deployment). Account-bearing
|
||||
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
|
||||
denied (it requires the webapp-auth surface).
|
||||
"""
|
||||
|
||||
def authorize(self, ctx: Context) -> bool:
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return False
|
||||
if ctx.tenant is None:
|
||||
return False
|
||||
return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
|
||||
if not account_id:
|
||||
return False
|
||||
row = db.session.execute(
|
||||
select(TenantAccountJoin.id).where(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == account_id,
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
|
||||
"""Set Flask-Login request user so downstream services see the caller."""
|
||||
current_app.login_manager._update_request_context_with_user(user)
|
||||
user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
|
||||
def applies_to(self, subject_type: SubjectType) -> bool: ...
|
||||
|
||||
def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.ACCOUNT
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.account_id is None:
|
||||
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
|
||||
account = db.session.get(Account, ctx.account_id)
|
||||
if account is None:
|
||||
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
|
||||
account.current_tenant = ctx.tenant
|
||||
_login_as(account)
|
||||
ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
|
||||
def applies_to(self, subject_type: SubjectType) -> bool:
|
||||
return subject_type == SubjectType.EXTERNAL_SSO
|
||||
|
||||
def mount(self, ctx: Context) -> None:
|
||||
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
|
||||
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
InvokeFrom.OPENAPI,
|
||||
tenant_id=ctx.tenant.id,
|
||||
app_id=ctx.app.id,
|
||||
user_id=ctx.subject_email,
|
||||
)
|
||||
_login_as(end_user)
|
||||
ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
9 api/controllers/openapi/index.py Normal file
@@ -0,0 +1,9 @@
from flask_restx import Resource

from controllers.openapi import openapi_ns


@openapi_ns.route("/_health")
class HealthApi(Resource):
    def get(self):
        return {"ok": True}
392 api/controllers/openapi/oauth_device.py Normal file
@@ -0,0 +1,392 @@
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
PREFIX_OAUTH_ACCOUNT,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Request / query schemas
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
|
||||
client_id: str
|
||||
device_label: str
|
||||
|
||||
|
||||
class DevicePollRequest(BaseModel):
|
||||
device_code: str
|
||||
client_id: str
|
||||
|
||||
|
||||
class DeviceLookupQuery(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
class DeviceMutateRequest(BaseModel):
|
||||
user_code: str
|
||||
|
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
|
||||
body = request.get_json(silent=True) or {}
|
||||
try:
|
||||
return model.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
|
||||
try:
|
||||
return model.model_validate(request.args.to_dict(flat=True))
|
||||
except ValidationError as exc:
|
||||
raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
|
||||
class OAuthDeviceCodeApi(Resource):
|
||||
@rate_limit(LIMIT_DEVICE_CODE_PER_IP)
|
||||
def post(self):
|
||||
payload = _validate_json(DeviceCodeRequest)
|
||||
client_id = payload.client_id
|
||||
device_label = payload.device_label
|
||||
|
||||
if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
|
||||
return {"error": "unsupported_client"}, 400
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
ip = extract_remote_ip(request)
|
||||
device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)
|
||||
|
||||
return {
|
||||
"device_code": device_code,
|
||||
"user_code": user_code,
|
||||
"verification_uri": _verification_uri(),
|
||||
"expires_in": expires_in,
|
||||
"interval": DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
|
||||
class OAuthDeviceTokenApi(Resource):
|
||||
"""RFC 8628 poll."""
|
||||
|
||||
def post(self):
|
||||
payload = _validate_json(DevicePollRequest)
|
||||
device_code = payload.device_code
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
|
||||
# slow_down beats every other branch — polling-too-fast clients
|
||||
# see only that response regardless of underlying state.
|
||||
if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
|
||||
return {"error": "slow_down"}, 400
|
||||
|
||||
state = store.load_by_device_code(device_code)
|
||||
if state is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if state.status is DeviceFlowStatus.PENDING:
|
||||
return {"error": "authorization_pending"}, 400
|
||||
|
||||
terminal = store.consume_on_poll(device_code)
|
||||
if terminal is None:
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
|
||||
_audit_cross_ip_if_needed(state)
|
||||
return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
|
||||
class OAuthDeviceLookupApi(Resource):
|
||||
"""Read-only — public for pre-validate before login. user_code is
|
||||
high-entropy + short-TTL; per-IP rate limit blocks enumeration.
|
||||
"""
|
||||
|
||||
@rate_limit(LIMIT_LOOKUP_PUBLIC)
|
||||
def get(self):
|
||||
payload = _validate_query(DeviceLookupQuery)
|
||||
user_code = payload.user_code.strip().upper()
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200
|
||||
|
||||
_device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
|
||||
"client_id": state.client_id,
|
||||
}, 200
|
||||
|
||||
|
||||
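The three protocol endpoints compose into the standard RFC 8628 client loop. A hedged client-side sketch; the base URL and client id are placeholders and error handling is trimmed:

# Illustrative device-flow client, not part of the diff.
import time

import requests

BASE = "https://cloud.example.com/openapi/v1"

start = requests.post(
    f"{BASE}/oauth/device/code",
    json={"client_id": "<client_id>", "device_label": "my-laptop"},
    timeout=10,
).json()
print("visit", start["verification_uri"], "and enter", start["user_code"])

interval = start["interval"]
while True:
    time.sleep(interval)
    r = requests.post(
        f"{BASE}/oauth/device/token",
        json={"device_code": start["device_code"], "client_id": "<client_id>"},
        timeout=10,
    )
    body = r.json()
    if r.status_code == 200:
        token = body["token"]  # the minted bearer plus workspace metadata
        break
    if body.get("error") == "slow_down":
        interval += 5  # RFC 8628 section 3.5: back off on slow_down
    elif body.get("error") != "authorization_pending":
        raise RuntimeError(body.get("error", "unknown"))
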
# =========================================================================
# Approval endpoints — account branch (cookie-authed)
# =========================================================================


_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
_APPROVE_GUARD_TTL_SECONDS = 10


@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        payload = _validate_json(DeviceMutateRequest)
        user_code = payload.user_code.strip().upper()

        account, tenant = current_account_with_tenant()
        store = DeviceFlowRedis(redis_client)

        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        # SET NX guard — without it, two in-flight approves both pass
        # PENDING, both mint, and the second upsert silently rotates the
        # first caller into an already-revoked token.
        guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
        if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
            return {"error": "approve_in_progress"}, 409

        try:
            ttl_days = oauth_ttl_days(tenant_id=tenant)
            mint = mint_oauth_token(
                db.session,
                redis_client,
                subject_email=account.email,
                subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                account_id=str(account.id),
                client_id=state.client_id,
                device_label=state.device_label,
                prefix=PREFIX_OAUTH_ACCOUNT,
                ttl_days=ttl_days,
            )

            poll_payload = _build_account_poll_payload(account, tenant, mint)
            try:
                store.approve(
                    device_code,
                    subject_email=account.email,
                    account_id=str(account.id),
                    subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                    minted_token=mint.token,
                    token_id=str(mint.token_id),
                    poll_payload=poll_payload,
                )
            except (StateNotFoundError, InvalidTransitionError):
                # Row minted but state vanished — roll forward; the orphan
                # token is revocable via auth devices list / Authorized Apps.
                logger.exception("device_flow: approve raced on %s", device_code)
                return {"error": "state_lost"}, 409
        finally:
            redis_client.delete(guard_key)

        _emit_approve_audit(state, account, tenant, mint)
        return {"status": "approved"}, 200


@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        payload = _validate_json(DeviceMutateRequest)
        user_code = payload.user_code.strip().upper()

        store = DeviceFlowRedis(redis_client)
        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        try:
            store.deny(device_code)
        except (StateNotFoundError, InvalidTransitionError):
            logger.exception("device_flow: deny raced on %s", device_code)
            return {"error": "state_lost"}, 409

        _emit_deny_audit(state)
        return {"status": "denied"}, 200


# =========================================================================
# Helpers
# =========================================================================


def _verification_uri() -> str:
    base = getattr(dify_config, "CONSOLE_WEB_URL", None)
    if base:
        return f"{base.rstrip('/')}/device"
    return f"{request.host_url.rstrip('/')}/device"


def _audit_cross_ip_if_needed(state) -> None:
    poll_ip = extract_remote_ip(request)
    if state.created_ip and poll_ip and poll_ip != state.created_ip:
        logger.warning(
            "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
            state.token_id,
            state.created_ip,
            poll_ip,
            extra={
                "audit": True,
                "token_id": state.token_id,
                "creation_ip": state.created_ip,
                "poll_ip": poll_ip,
            },
        )


def _build_account_poll_payload(account, tenant, mint) -> dict:
    """Pre-render the poll-response body so the unauthenticated poll
    handler doesn't re-query accounts/tenants for authz data.
    """
    from models import Tenant, TenantAccountJoin

    rows = (
        db.session.query(Tenant, TenantAccountJoin)
        .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.account_id == account.id)
        .all()
    )
    workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
    # Prefer active session tenant → DB-flagged current join → first membership.
    default_ws_id = None
    if tenant and any(w["id"] == str(tenant) for w in workspaces):
        default_ws_id = str(tenant)
    if default_ws_id is None:
        for _t, m in rows:
            if getattr(m, "current", False):
                default_ws_id = str(m.tenant_id)
                break
    if default_ws_id is None and workspaces:
        default_ws_id = workspaces[0]["id"]

    return {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.ACCOUNT,
        "account": {"id": str(account.id), "email": account.email, "name": account.name},
        "workspaces": workspaces,
        "default_workspace_id": default_ws_id,
        "token_id": str(mint.token_id),
    }


def _emit_approve_audit(state, account, tenant, mint) -> None:
    logger.warning(
        "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
        mint.token_id,
        account.email,
        state.client_id,
        state.device_label,
        mint.expires_at,
        extra={
            "audit": True,
            "event": "oauth.device_flow_approved",
            "token_id": str(mint.token_id),
            "subject_type": SubjectType.ACCOUNT,
            "subject_email": account.email,
            "account_id": str(account.id),
            "tenant_id": tenant,
            "client_id": state.client_id,
            "device_label": state.device_label,
            "scopes": ["full"],
            "expires_at": mint.expires_at.isoformat(),
        },
    )


def _emit_deny_audit(state) -> None:
    logger.warning(
        "audit: oauth.device_flow_denied client_id=%s device_label=%s",
        state.client_id,
        state.device_label,
        extra={
            "audit": True,
            "event": "oauth.device_flow_denied",
            "client_id": state.client_id,
            "device_label": state.device_label,
        },
    )
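The SET NX guard used in DeviceApproveApi is a generic single-flight pattern worth seeing in isolation. A hedged standalone sketch, not project code:

# Generic single-flight guard, illustrative only.
from contextlib import contextmanager

@contextmanager
def single_flight(redis_client, key: str, ttl_seconds: int = 10):
    # SET NX acquires the guard atomically; the TTL bounds crash leakage.
    if not redis_client.set(key, "1", nx=True, ex=ttl_seconds):
        raise RuntimeError("operation already in progress")
    try:
        yield
    finally:
        redis_client.delete(key)

# with single_flight(redis_client, f"device_code:{code}:approving"):
#     ...mint and approve exactly once...
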
287 api/controllers/openapi/oauth_device_sso.py Normal file
@@ -0,0 +1,287 @@
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from controllers.openapi import bp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
|
||||
# device_code it references.
|
||||
STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
|
||||
# Canonical sso-complete path. IdP-side ACS callback URL must point here.
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
def sso_initiate():
|
||||
user_code = (request.args.get("user_code") or "").strip().upper()
|
||||
if not user_code:
|
||||
raise BadRequest("user_code required")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise BadRequest("invalid_user_code")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
payload={
|
||||
"redirect_url": "",
|
||||
"app_code": "",
|
||||
"intent": "device_flow",
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
reply = EnterpriseService.initiate_device_flow_sso(signed_state)
|
||||
except Exception as e:
|
||||
logger.warning("sso-initiate: enterprise call failed: %s", e)
|
||||
raise BadGateway("sso_initiate_failed") from e
|
||||
|
||||
url = (reply or {}).get("url")
|
||||
if not url:
|
||||
raise BadGateway("sso_initiate_missing_url")
|
||||
|
||||
# Clear stale approval-grant — defends against cross-tab/back-button mixing.
|
||||
resp = redirect(url, code=302)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
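The state envelope minted above is expected to come back unmodified; verification mirrors the `jws.verify` call used in `sso_complete` below. A hedged sketch, shown as comments plus code because where it runs (EE service versus this API) is an assumption:

# Illustrative verification of the state envelope, not part of this diff.
# Mirrors sso_complete's use of jws.verify with an expected audience.
keyset = jws.KeySet.from_shared_secret()
claims = jws.verify(keyset, signed_state, expected_aud=jws.AUD_STATE_ENVELOPE)
assert claims["intent"] == "device_flow"
user_code = claims["user_code"]  # ties the browser leg back to the pending device_code
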
@bp.route("/oauth/device/sso-complete", methods=["GET"])
|
||||
@enterprise_only
|
||||
def sso_complete():
|
||||
blob = request.args.get("sso_assertion")
|
||||
if not blob:
|
||||
raise BadRequest("sso_assertion required")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = (claims.get("user_code") or "").strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
raise Conflict("user_code_not_pending")
|
||||
_, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
iss = request.host_url.rstrip("/")
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims["email"],
|
||||
subject_issuer=claims["issuer"],
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
resp = redirect("/device?sso_verified=1", code=302)
|
||||
resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
|
||||
return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
|
||||
@enterprise_only
|
||||
def approval_context():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("no_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approval-context: bad cookie: %s", e)
|
||||
raise Unauthorized("no_session") from e
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"user_code": claims.user_code,
|
||||
"csrf_token": claims.csrf_token,
|
||||
"expires_at": claims.expires_at.isoformat(),
|
||||
}
|
||||
), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
|
||||
@enterprise_only
|
||||
def approve_external():
|
||||
token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
|
||||
if not token:
|
||||
raise Unauthorized("invalid_session")
|
||||
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
try:
|
||||
claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("approve-external: bad cookie: %s", e)
|
||||
raise Unauthorized("invalid_session") from e
|
||||
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or csrf_header != claims.csrf_token:
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
body_user_code = (data.get("user_code") or "").strip().upper()
|
||||
if body_user_code != claims.user_code:
|
||||
raise BadRequest("user_code_mismatch")
|
||||
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(claims.user_code)
|
||||
if found is None:
|
||||
raise NotFound("user_code_not_pending")
|
||||
device_code, state = found
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if not consume_approval_grant_nonce(redis_client, claims.nonce):
|
||||
raise Unauthorized("session_already_consumed")
|
||||
|
||||
ttl_days = oauth_ttl_days(tenant_id=None)
|
||||
mint = mint_oauth_token(
|
||||
db.session,
|
||||
redis_client,
|
||||
subject_email=claims.subject_email,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
account_id=None,
|
||||
client_id=state.client_id,
|
||||
device_label=state.device_label,
|
||||
prefix=PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
|
||||
try:
|
||||
store.approve(
|
||||
device_code,
|
||||
subject_email=claims.subject_email,
|
||||
account_id=None,
|
||||
subject_issuer=claims.subject_issuer,
|
||||
minted_token=mint.token,
|
||||
token_id=str(mint.token_id),
|
||||
poll_payload=poll_payload,
|
||||
)
|
||||
except (StateNotFoundError, InvalidTransitionError) as e:
|
||||
logger.exception("approve-external: state transition raced")
|
||||
raise Conflict("state_lost") from e
|
||||
|
||||
_emit_approve_external_audit(state, claims, mint)
|
||||
|
||||
resp = make_response(jsonify({"status": "approved"}), 200)
|
||||
resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
|
||||
return resp
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
|
||||
logger.warning(
|
||||
"audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
|
||||
SubjectType.EXTERNAL_SSO,
|
||||
claims.subject_email,
|
||||
claims.subject_issuer,
|
||||
mint.token_id,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": "oauth.device_flow_approved",
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
"subject_email": claims.subject_email,
|
||||
"subject_issuer": claims.subject_issuer,
|
||||
"token_id": str(mint.token_id),
|
||||
"client_id": state.client_id,
|
||||
"device_label": state.device_label,
|
||||
"scopes": ["apps:run"],
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
89 api/controllers/openapi/workspaces.py Normal file
@@ -0,0 +1,89 @@
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
|
||||
class WorkspacesApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
|
||||
return {"workspaces": []}, 200
|
||||
|
||||
rows = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.account_id == str(ctx.account_id))
|
||||
.order_by(Tenant.created_at.asc())
|
||||
).all()
|
||||
|
||||
return {"workspaces": list(starmap(_workspace_summary, rows))}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
|
||||
class WorkspaceByIdApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self, workspace_id: str):
|
||||
ctx = g.auth_ctx
|
||||
# External SSO + missing account → never a member of anything; 404.
|
||||
if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
row = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(
|
||||
Tenant.id == workspace_id,
|
||||
TenantAccountJoin.account_id == str(ctx.account_id),
|
||||
)
|
||||
).first()
|
||||
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
|
||||
if row is None:
|
||||
raise NotFound("workspace not found")
|
||||
|
||||
tenant, membership = row
|
||||
return _workspace_detail(tenant, membership), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> dict:
|
||||
return {
|
||||
"id": str(tenant.id),
|
||||
"name": tenant.name,
|
||||
"role": getattr(membership, "role", ""),
|
||||
"status": tenant.status,
|
||||
"current": getattr(membership, "current", False),
|
||||
}
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> dict:
|
||||
return {
|
||||
"id": str(tenant.id),
|
||||
"name": tenant.name,
|
||||
"role": getattr(membership, "role", ""),
|
||||
"status": tenant.status,
|
||||
"current": getattr(membership, "current", False),
|
||||
"created_at": tenant.created_at.isoformat() if tenant.created_at else None,
|
||||
}
|
||||
@@ -23,11 +23,9 @@ from .app import (
    conversation,
    file,
    file_preview,
    human_input_form,
    message,
    site,
    workflow,
    workflow_events,
)
from .dataset import (
    dataset,
@@ -52,7 +50,6 @@ __all__ = [
    "file",
    "file_preview",
    "hit_testing",
    "human_input_form",
    "index",
    "message",
    "metadata",
@@ -61,7 +58,6 @@ __all__ = [
    "segment",
    "site",
    "workflow",
    "workflow_events",
]

api.add_namespace(service_api_ns)
@@ -1,137 +0,0 @@
"""
Service API human input form endpoints.

This module exposes app-token authenticated APIs for fetching and submitting
paused human input forms in workflow/chatflow runs.
"""

import json
import logging
from datetime import datetime

from flask import Response
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound

from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService

logger = logging.getLogger(__name__)


register_schema_models(service_api_ns, HumanInputFormSubmitPayload)


def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
    result: dict[str, str] = {}
    for key, value in values.items():
        if value is None:
            result[key] = ""
        elif isinstance(value, (dict, list)):
            result[key] = json.dumps(value, ensure_ascii=False)
        else:
            result[key] = str(value)
    return result


def _to_timestamp(value: datetime) -> int:
    return int(value.timestamp())


def _jsonify_form_definition(form: Form) -> Response:
    definition_payload = form.get_definition().model_dump()
    payload = {
        "form_content": definition_payload["rendered_content"],
        "inputs": definition_payload["inputs"],
        "resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
        "user_actions": definition_payload["user_actions"],
        "expiration_time": _to_timestamp(form.expiration_time),
    }
    return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")


def _ensure_form_belongs_to_app(form: Form, app_model: App) -> None:
    if form.app_id != app_model.id or form.tenant_id != app_model.tenant_id:
        raise NotFound("Form not found")


def _ensure_form_is_allowed_for_service_api(form: Form) -> None:
    # Keep app-token callers scoped to the public web-form surface; internal HITL
    # routes must continue to flow through console-only authentication.
    if not is_recipient_type_allowed_for_surface(form.recipient_type, HumanInputSurface.SERVICE_API):
        raise NotFound("Form not found")


@service_api_ns.route("/form/human_input/<string:form_token>")
class WorkflowHumanInputFormApi(Resource):
    @service_api_ns.doc("get_human_input_form")
    @service_api_ns.doc(description="Get a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form retrieved successfully",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token
    def get(self, app_model: App, form_token: str):
        service = HumanInputService(db.engine)
        form = service.get_form_by_token(form_token)
        if form is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(form, app_model)
        _ensure_form_is_allowed_for_service_api(form)
        service.ensure_form_active(form)
        return _jsonify_form_definition(form)

    @service_api_ns.expect(service_api_ns.models[HumanInputFormSubmitPayload.__name__])
    @service_api_ns.doc("submit_human_input_form")
    @service_api_ns.doc(description="Submit a paused human input form by token")
    @service_api_ns.doc(params={"form_token": "Human input form token"})
    @service_api_ns.doc(
        responses={
            200: "Form submitted successfully",
            400: "Bad request - invalid submission data",
            401: "Unauthorized - invalid API token",
            404: "Form not found",
            412: "Form already submitted or expired",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
    def post(self, app_model: App, end_user: EndUser, form_token: str):
        payload = HumanInputFormSubmitPayload.model_validate(service_api_ns.payload or {})

        service = HumanInputService(db.engine)
        form = service.get_form_by_token(form_token)
        if form is None:
            raise NotFound("Form not found")

        _ensure_form_belongs_to_app(form, app_model)
        _ensure_form_is_allowed_for_service_api(form)

        recipient_type = form.recipient_type
        if recipient_type is None:
            logger.warning("Recipient type is None for form, form_id=%s", form.id)
            raise BadRequest("Form recipient type is invalid")

        try:
            service.submit_form_by_token(
                recipient_type=recipient_type,
                form_token=form_token,
                selected_action_id=payload.action,
                form_data=payload.inputs,
                submission_end_user_id=end_user.id,
            )
        except FormNotFoundError:
            raise NotFound("Form not found")

        return {}, 200
@@ -1,142 +0,0 @@
"""
Service API workflow resume event stream endpoints.
"""

import json
from collections.abc import Generator

from flask import Response, request
from flask_restx import Resource
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound

from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.task_entities import StreamEvent
from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream


@service_api_ns.route("/workflow/<string:task_id>/events")
class WorkflowEventsApi(Resource):
    """Service API for getting workflow execution events after resume."""

    @service_api_ns.doc("get_workflow_events")
    @service_api_ns.doc(description="Get workflow execution events stream after resume")
    @service_api_ns.doc(
        params={
            "task_id": "Workflow run ID",
            "user": "End user identifier (query param)",
            "include_state_snapshot": (
                "Whether to replay from persisted state snapshot, "
                'specify `"true"` to include a status snapshot of executed nodes'
            ),
            "continue_on_pause": (
                "Whether to keep the stream open across workflow_paused events, "
                'specify `"true"` to keep the stream open for `workflow_paused` events.'
            ),
        }
    )
    @service_api_ns.doc(
        responses={
            200: "SSE event stream",
            401: "Unauthorized - invalid API token",
            404: "Workflow run not found",
        }
    )
    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
    def get(self, app_model: App, end_user: EndUser, task_id: str):
        app_mode = AppMode.value_of(app_model.mode)
        if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
            raise NotWorkflowAppError()

        session_maker = sessionmaker(db.engine)
        repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
        workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
            tenant_id=app_model.tenant_id,
            run_id=task_id,
        )

        if workflow_run is None:
            raise NotFound("Workflow run not found")

        if workflow_run.app_id != app_model.id:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by_role != CreatorUserRole.END_USER:
            raise NotFound("Workflow run not found")

        if workflow_run.created_by != end_user.id:
            raise NotFound("Workflow run not found")

        workflow_run_entity = workflow_run

        if workflow_run_entity.finished_at is not None:
            response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
                task_id=workflow_run_entity.id,
                workflow_run=workflow_run_entity,
                creator_user=end_user,
            )

            payload = response.model_dump(mode="json")
            payload["event"] = response.event.value

            def _generate_finished_events() -> Generator[str, None, None]:
                yield f"data: {json.dumps(payload)}\n\n"

            event_generator = _generate_finished_events
        else:
            msg_generator = MessageGenerator()
            generator: BaseAppGenerator
            if app_mode == AppMode.ADVANCED_CHAT:
                generator = AdvancedChatAppGenerator()
            elif app_mode == AppMode.WORKFLOW:
                generator = WorkflowAppGenerator()
            else:
                raise NotWorkflowAppError()

            include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
            continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
            terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None

            def _generate_stream_events():
                if include_state_snapshot:
                    return generator.convert_to_event_stream(
                        build_workflow_event_stream(
                            app_mode=app_mode,
                            workflow_run=workflow_run_entity,
                            tenant_id=app_model.tenant_id,
                            app_id=app_model.id,
                            session_maker=session_maker,
                            human_input_surface=HumanInputSurface.SERVICE_API,
                            close_on_pause=not continue_on_pause,
                        )
                    )
                return generator.convert_to_event_stream(
                    msg_generator.retrieve_events(
                        app_mode,
                        workflow_run_entity.id,
                        terminal_events=terminal_events,
                    ),
                )

            event_generator = _generate_stream_events

        return Response(
            event_generator(),
            mimetype="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
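Usage sketch (placeholders throughout): reading the SSE stream from the GET /workflow/<task_id>/events endpoint above, with the two query flags it documents.

import requests

with requests.get(
    "https://api.example.com/v1/workflow/<task_id>/events",
    headers={"Authorization": "Bearer app-xxx"},
    params={
        "user": "end-user-123",  # WhereisUserArg.QUERY: user id is a query param
        "include_state_snapshot": "true",  # replay a snapshot of executed nodes first
        "continue_on_pause": "true",  # keep the stream open across workflow_paused
    },
    stream=True,
    timeout=60,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):])  # one JSON event payload per data line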
@@ -1,12 +1,4 @@
"""Service API endpoints for dataset document management.

The canonical Service API paths use hyphenated route segments. Legacy underscore
aliases remain registered for backward compatibility, but they must stay marked
deprecated in generated API docs so clients migrate toward the canonical paths.
"""

import json
from collections.abc import Mapping
from contextlib import ExitStack
from typing import Self
from uuid import UUID
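Illustrative only: the canonical hyphenated path and its deprecated underscore alias resolve to the same handler, so the two requests below behave identically. Host, token, and the payload values are placeholders.

import requests

canonical = "/v1/datasets/<dataset_id>/document/create-by-text"
legacy = "/v1/datasets/<dataset_id>/document/create_by_text"  # deprecated alias
for path in (canonical, legacy):
    resp = requests.post(
        f"https://api.example.com{path}",
        headers={"Authorization": "Bearer dataset-xxx"},
        json={"name": "notes.txt", "text": "hello", "indexing_technique": "economy"},
        timeout=30,
    )
    resp.raise_for_status()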
@@ -125,137 +117,12 @@ register_schema_models(
)


def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
    """Create a document from text for both canonical and legacy routes."""
    payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
    args = payload.model_dump(exclude_none=True)

    dataset_id_str = str(dataset_id)
    tenant_id_str = str(tenant_id)
    dataset = db.session.scalar(
        select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
    )

    if not dataset:
        raise ValueError("Dataset does not exist.")

    if not dataset.indexing_technique and not args["indexing_technique"]:
        raise ValueError("indexing_technique is required.")

    embedding_model_provider = payload.embedding_model_provider
    embedding_model = payload.embedding_model
    if embedding_model_provider and embedding_model:
        DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)

    retrieval_model = payload.retrieval_model
    if (
        retrieval_model
        and retrieval_model.reranking_model
        and retrieval_model.reranking_model.reranking_provider_name
        and retrieval_model.reranking_model.reranking_model_name
    ):
        DatasetService.check_reranking_model_setting(
            tenant_id_str,
            retrieval_model.reranking_model.reranking_provider_name,
            retrieval_model.reranking_model.reranking_model_name,
        )

    if not current_user:
        raise ValueError("current_user is required")

    upload_file = FileService(db.engine).upload_text(
        text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
    )
    data_source = {
        "type": "upload_file",
        "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
    }
    args["data_source"] = data_source
    knowledge_config = KnowledgeConfig.model_validate(args)
    DocumentService.document_create_args_validate(knowledge_config)

    if not current_user:
        raise ValueError("current_user is required")

    try:
        documents, batch = DocumentService.save_document_with_dataset_id(
            dataset=dataset,
            knowledge_config=knowledge_config,
            account=current_user,
            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
            created_from="api",
        )
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    document = documents[0]

    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
    return documents_and_batch_fields, 200
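Hypothetical payload sketch for the helper above, shown only to illustrate which optional blocks trigger the embedding and reranking checks; every provider and model value here is an assumption, not taken from this diff.

payload = {
    "name": "faq.txt",
    "text": "Q: ...\nA: ...",
    "indexing_technique": "high_quality",
    # both present -> DatasetService.check_embedding_model_setting runs
    "embedding_model_provider": "openai",
    "embedding_model": "text-embedding-3-small",
    "retrieval_model": {
        # both names present -> DatasetService.check_reranking_model_setting runs
        "reranking_model": {
            "reranking_provider_name": "cohere",
            "reranking_model_name": "rerank-english-v3.0",
        }
    },
}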
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
    """Update a document from text for both canonical and legacy routes."""
    payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
    dataset = db.session.scalar(
        select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
    )
    args = payload.model_dump(exclude_none=True)
    if not dataset:
        raise ValueError("Dataset does not exist.")

    retrieval_model = payload.retrieval_model
    if (
        retrieval_model
        and retrieval_model.reranking_model
        and retrieval_model.reranking_model.reranking_provider_name
        and retrieval_model.reranking_model.reranking_model_name
    ):
        DatasetService.check_reranking_model_setting(
            tenant_id,
            retrieval_model.reranking_model.reranking_provider_name,
            retrieval_model.reranking_model.reranking_model_name,
        )

    # indexing_technique is already set in dataset since this is an update
    args["indexing_technique"] = dataset.indexing_technique

    if args.get("text"):
        text = args.get("text")
        name = args.get("name")
        if not current_user:
            raise ValueError("current_user is required")
        upload_file = FileService(db.engine).upload_text(
            text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
        )
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source

    args["original_document_id"] = str(document_id)
    knowledge_config = KnowledgeConfig.model_validate(args)
    DocumentService.document_create_args_validate(knowledge_config)

    try:
        documents, batch = DocumentService.save_document_with_dataset_id(
            dataset=dataset,
            knowledge_config=knowledge_config,
            account=current_user,
            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
            created_from="api",
        )
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    document = documents[0]

    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
    return documents_and_batch_fields, 200


@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/document/create_by_text",
    "/datasets/<uuid:dataset_id>/document/create-by-text",
)
class DocumentAddByTextApi(DatasetApiResource):
    """Resource for the canonical text document creation route."""
    """Resource for documents."""

    @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
    @service_api_ns.doc("create_document_by_text")
@@ -271,43 +138,81 @@ class DocumentAddByTextApi(DatasetApiResource):
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_resource_check("documents", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID):
    def post(self, tenant_id, dataset_id):
        """Create document by text."""
        return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
        payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
        args = payload.model_dump(exclude_none=True)


@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
    """Deprecated resource alias for text document creation."""

    @service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
    @service_api_ns.doc("create_document_by_text_deprecated")
    @service_api_ns.doc(deprecated=True)
    @service_api_ns.doc(
        description=(
            "Deprecated legacy alias for creating a new document by providing text content. "
            "Use /datasets/{dataset_id}/document/create-by-text instead."
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)
        dataset = db.session.scalar(
            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
        )
        )
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Document created successfully",
            401: "Unauthorized - invalid API token",
            400: "Bad request - invalid parameters",

        if not dataset:
            raise ValueError("Dataset does not exist.")

        if not dataset.indexing_technique and not args["indexing_technique"]:
            raise ValueError("indexing_technique is required.")

        embedding_model_provider = payload.embedding_model_provider
        embedding_model = payload.embedding_model
        if embedding_model_provider and embedding_model:
            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

        retrieval_model = payload.retrieval_model
        if (
            retrieval_model
            and retrieval_model.reranking_model
            and retrieval_model.reranking_model.reranking_provider_name
            and retrieval_model.reranking_model.reranking_model_name
        ):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                retrieval_model.reranking_model.reranking_provider_name,
                retrieval_model.reranking_model.reranking_model_name,
            )

        if not current_user:
            raise ValueError("current_user is required")

        upload_file = FileService(db.engine).upload_text(
            text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
        )
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_resource_check("documents", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID):
        """Create document by text through the deprecated underscore alias."""
        return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
        args["data_source"] = data_source
        knowledge_config = KnowledgeConfig.model_validate(args)
        # validate args
        DocumentService.document_create_args_validate(knowledge_config)

        if not current_user:
            raise ValueError("current_user is required")

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=current_user,
                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        document = documents[0]

        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
        return documents_and_batch_fields, 200


@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
class DocumentUpdateByTextApi(DatasetApiResource):
    """Resource for the canonical text document update route."""
    """Resource for update documents."""

    @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
    @service_api_ns.doc("update_document_by_text")
@@ -324,35 +229,62 @@ class DocumentUpdateByTextApi(DatasetApiResource):
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by text."""
        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)


@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
    """Deprecated resource alias for text document updates."""

    @service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
    @service_api_ns.doc("update_document_by_text_deprecated")
    @service_api_ns.doc(deprecated=True)
    @service_api_ns.doc(
        description=(
            "Deprecated legacy alias for updating an existing document by providing text content. "
            "Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
        payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
        dataset = db.session.scalar(
            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
        )
        )
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by text through the deprecated underscore alias."""
        return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
        args = payload.model_dump(exclude_none=True)
        if not dataset:
            raise ValueError("Dataset does not exist.")

        retrieval_model = payload.retrieval_model
        if (
            retrieval_model
            and retrieval_model.reranking_model
            and retrieval_model.reranking_model.reranking_provider_name
            and retrieval_model.reranking_model.reranking_model_name
        ):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                retrieval_model.reranking_model.reranking_provider_name,
                retrieval_model.reranking_model.reranking_model_name,
            )

        # indexing_technique is already set in dataset since this is an update
        args["indexing_technique"] = dataset.indexing_technique

        if args.get("text"):
            text = args.get("text")
            name = args.get("name")
            if not current_user:
                raise ValueError("current_user is required")
            upload_file = FileService(db.engine).upload_text(
                text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
            )
            data_source = {
                "type": "upload_file",
                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
            }
            args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)
        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        try:
            documents, batch = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=current_user,
                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        document = documents[0]

        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
        return documents_and_batch_fields, 200


@service_api_ns.route(
@@ -468,98 +400,15 @@ class DocumentAddByFileApi(DatasetApiResource):
        return documents_and_batch_fields, 200


def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
    """Update a document from an uploaded file for canonical and deprecated routes."""
    dataset_id_str = str(dataset_id)
    tenant_id_str = str(tenant_id)
    dataset = db.session.scalar(
        select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
    )

    if not dataset:
        raise ValueError("Dataset does not exist.")

    if dataset.provider == "external":
        raise ValueError("External datasets are not supported.")

    args: dict[str, object] = {}
    if "data" in request.form:
        args = json.loads(request.form["data"])
    if "doc_form" not in args:
        args["doc_form"] = dataset.chunk_structure or "text_model"
    if "doc_language" not in args:
        args["doc_language"] = "English"

    # indexing_technique is already set in dataset since this is an update
    args["indexing_technique"] = dataset.indexing_technique

    if "file" in request.files:
        # save file info
        file = request.files["file"]

        if len(request.files) > 1:
            raise TooManyFilesError()

        if not file.filename:
            raise FilenameNotExistsError

        if not current_user:
            raise ValueError("current_user is required")

        try:
            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
                source="datasets",
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source

    # validate args
    args["original_document_id"] = str(document_id)

    knowledge_config = KnowledgeConfig.model_validate(args)
    DocumentService.document_create_args_validate(knowledge_config)

    try:
        documents, _ = DocumentService.save_document_with_dataset_id(
            dataset=dataset,
            knowledge_config=knowledge_config,
            account=dataset.created_by_account,
            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
            created_from="api",
        )
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    document = documents[0]
    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
    return documents_and_batch_fields, 200
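Usage sketch (placeholders throughout, not taken from this diff): exercising the file-update handler above, which reads an optional JSON "data" form field plus exactly one "file" part.

import json

import requests

resp = requests.post(
    "https://api.example.com/v1/datasets/<dataset_id>/documents/<document_id>/update-by-file",
    headers={"Authorization": "Bearer dataset-xxx"},
    data={"data": json.dumps({"doc_form": "text_model", "doc_language": "English"})},
    files={"file": ("manual.pdf", open("manual.pdf", "rb"), "application/pdf")},  # >1 file raises TooManyFilesError
    timeout=120,
)
resp.raise_for_status()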
@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
    """Deprecated resource aliases for file document updates."""
class DocumentUpdateByFileApi(DatasetApiResource):
    """Resource for update documents."""

    @service_api_ns.doc("update_document_by_file_deprecated")
    @service_api_ns.doc(deprecated=True)
    @service_api_ns.doc(
        description=(
            "Deprecated legacy alias for updating an existing document by uploading a file. "
            "Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
        )
    )
    @service_api_ns.doc("update_document_by_file")
    @service_api_ns.doc(description="Update an existing document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
@@ -570,9 +419,82 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by file through the deprecated file-update aliases."""
        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
    def post(self, tenant_id, dataset_id, document_id):
        """Update document by upload file."""
        dataset = db.session.scalar(
            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
        )

        if not dataset:
            raise ValueError("Dataset does not exist.")

        if dataset.provider == "external":
            raise ValueError("External datasets are not supported.")

        args = {}
        if "data" in request.form:
            args = json.loads(request.form["data"])
        if "doc_form" not in args:
            args["doc_form"] = dataset.chunk_structure or "text_model"
        if "doc_language" not in args:
            args["doc_language"] = "English"

        # get dataset info
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        # indexing_technique is already set in dataset since this is an update
        args["indexing_technique"] = dataset.indexing_technique

        if "file" in request.files:
            # save file info
            file = request.files["file"]

            if len(request.files) > 1:
                raise TooManyFilesError()

            if not file.filename:
                raise FilenameNotExistsError

            if not current_user:
                raise ValueError("current_user is required")

            try:
                upload_file = FileService(db.engine).upload_file(
                    filename=file.filename,
                    content=file.read(),
                    mimetype=file.mimetype,
                    user=current_user,
                    source="datasets",
                )
            except services.errors.file.FileTooLargeError as file_too_large_error:
                raise FileTooLargeError(file_too_large_error.description)
            except services.errors.file.UnsupportedFileTypeError:
                raise UnsupportedFileTypeError()
            data_source = {
                "type": "upload_file",
                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
            }
            args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)

        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        try:
            documents, _ = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=dataset.created_by_account,
                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        document = documents[0]
        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
        return documents_and_batch_fields, 200


@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@@ -886,22 +808,6 @@ class DocumentApi(DatasetApiResource):

        return response

    @service_api_ns.doc("update_document_by_file")
    @service_api_ns.doc(description="Update an existing document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by file on the canonical document resource."""
        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)

    @service_api_ns.doc("delete_document")
    @service_api_ns.doc(description="Delete a document")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

@@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@@ -26,6 +26,11 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str


_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
    prefix="web_form_submit_rate_limit",
    max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

@@ -1,7 +1,5 @@
from typing import Any

CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000


class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
@@ -22,11 +20,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for suggested questions feature.

        Optional fields:
        - prompt: custom instruction prompt.
        - model: provider/model configuration for suggested question generation.
        Validate and set defaults for suggested questions feature

        :param config: app model config args
        """
@@ -45,27 +39,4 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
            raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")

        prompt = config["suggested_questions_after_answer"].get("prompt")
        if prompt is not None and not isinstance(prompt, str):
            raise ValueError("prompt in suggested_questions_after_answer must be of string type")
        if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
            raise ValueError(
                f"prompt in suggested_questions_after_answer must be less than or equal to "
                f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
            )

        if "model" in config["suggested_questions_after_answer"]:
            model_config = config["suggested_questions_after_answer"]["model"]
            if not isinstance(model_config, dict):
                raise ValueError("model in suggested_questions_after_answer must be of object type")

            if "provider" not in model_config or not isinstance(model_config["provider"], str):
                raise ValueError("provider in suggested_questions_after_answer.model must be of string type")

            if "name" not in model_config or not isinstance(model_config["name"], str):
                raise ValueError("name in suggested_questions_after_answer.model must be of string type")

            if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
                raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")

        return config, ["suggested_questions_after_answer"]
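Hypothetical config fragment that would pass the validation above; the provider and model names are illustrative. "enabled" must be a bool, "prompt" a string no longer than CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH, and "model" (when present) an object with string "provider" and "name" plus an optional dict "completion_params".

config = {
    "suggested_questions_after_answer": {
        "enabled": True,
        "prompt": "Suggest three short follow-up questions about shipping.",
        "model": {
            "provider": "openai",
            "name": "gpt-4o-mini",
            "completion_params": {"temperature": 0.3},
        },
    }
}
validated, affected_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(config)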
@@ -34,11 +34,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import (
    AdvancedChatPausedBlockingResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
)
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.ops.ops_trace_manager import TraceQueueManager
@@ -659,11 +655,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        user: Account | EndUser,
        draft_var_saver_factory: DraftVariableSaverFactory,
        stream: bool = False,
    ) -> (
        ChatbotAppBlockingResponse
        | AdvancedChatPausedBlockingResponse
        | Generator[ChatbotAppStreamResponse, None, None]
    ):
    ) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
        """
        Handle response.
        :param application_generate_entity: application generate entity

@@ -3,7 +3,7 @@ from typing import Any, cast

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AdvancedChatPausedBlockingResponse,
    AppBlockingResponse,
    AppStreamResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
@@ -12,40 +12,22 @@ from core.app.entities.task_entities import (
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    PingStreamResponse,
    StreamEvent,
)


class AdvancedChatAppGenerateResponseConverter(
    AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(
        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
    ) -> dict[str, Any]:
    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        """
        Convert blocking full response.
        :param blocking_response: blocking response
        :return:
        """
        if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
            paused_data = blocking_response.data.model_dump(mode="json")
            return {
                "event": StreamEvent.WORKFLOW_PAUSED.value,
                "task_id": blocking_response.task_id,
                "id": blocking_response.data.id,
                "message_id": blocking_response.data.message_id,
                "conversation_id": blocking_response.data.conversation_id,
                "mode": blocking_response.data.mode,
                "answer": blocking_response.data.answer,
                "metadata": blocking_response.data.metadata,
                "created_at": blocking_response.data.created_at,
                "workflow_run_id": blocking_response.data.workflow_run_id,
                "data": paused_data,
            }

        blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
        response = {
            "event": StreamEvent.MESSAGE.value,
            "event": "message",
            "task_id": blocking_response.task_id,
            "id": blocking_response.data.id,
            "message_id": blocking_response.data.message_id,
@@ -59,9 +41,7 @@ class AdvancedChatAppGenerateResponseConverter(
        return response

    @classmethod
    def convert_blocking_simple_response(
        cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
    ) -> dict[str, Any]:
    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -70,8 +50,7 @@ class AdvancedChatAppGenerateResponseConverter(
        response = cls.convert_blocking_full_response(blocking_response)

        metadata = response.get("metadata", {})
        if isinstance(metadata, dict):
            response["metadata"] = cls._get_simple_metadata(metadata)
        response["metadata"] = cls._get_simple_metadata(metadata)

        return response


@@ -53,18 +53,14 @@ from core.app.entities.queue_entities import (
    WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
    AdvancedChatPausedBlockingResponse,
    ChatbotAppBlockingResponse,
    ChatbotAppStreamResponse,
    ErrorStreamResponse,
    HumanInputRequiredPauseReasonPayload,
    HumanInputRequiredResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
    MessageEndStreamResponse,
    PingStreamResponse,
    StreamResponse,
    WorkflowPauseStreamResponse,
    WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -214,13 +210,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        if message.status == MessageStatus.PAUSED and message.answer:
            self._task_state.answer = message.answer

    def process(
        self,
    ) -> Union[
        ChatbotAppBlockingResponse,
        AdvancedChatPausedBlockingResponse,
        Generator[ChatbotAppStreamResponse, None, None],
    ]:
    def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
        """
        Process generate task pipeline.
        :return:
@@ -236,39 +226,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> ChatbotAppBlockingResponse:
        """
        Process blocking response.
        :return:
        """
        human_input_responses: list[HumanInputRequiredResponse] = []
        for stream_response in generator:
            if isinstance(stream_response, ErrorStreamResponse):
                raise stream_response.err
            elif isinstance(stream_response, HumanInputRequiredResponse):
                human_input_responses.append(stream_response)
            elif isinstance(stream_response, WorkflowPauseStreamResponse):
                return AdvancedChatPausedBlockingResponse(
                    task_id=stream_response.task_id,
                    data=AdvancedChatPausedBlockingResponse.Data(
                        id=self._message_id,
                        mode=self._conversation_mode,
                        conversation_id=self._conversation_id,
                        message_id=self._message_id,
                        workflow_run_id=stream_response.data.workflow_run_id,
                        answer=self._task_state.answer,
                        metadata=self._message_end_to_stream_response().metadata,
                        created_at=self._message_created_at,
                        paused_nodes=stream_response.data.paused_nodes,
                        reasons=stream_response.data.reasons,
                        status=stream_response.data.status,
                        elapsed_time=stream_response.data.elapsed_time,
                        total_tokens=stream_response.data.total_tokens,
                        total_steps=stream_response.data.total_steps,
                    ),
                )
            elif isinstance(stream_response, MessageEndStreamResponse):
                extras = {}
                if stream_response.metadata:
@@ -289,41 +254,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
            else:
                continue

        if human_input_responses:
            return self._build_paused_blocking_response_from_human_input(human_input_responses)

        raise ValueError("queue listening stopped unexpectedly.")

    def _build_paused_blocking_response_from_human_input(
        self, human_input_responses: list[HumanInputRequiredResponse]
    ) -> AdvancedChatPausedBlockingResponse:
        runtime_state = self._resolve_graph_runtime_state()
        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
        reasons = [
            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
            for response in human_input_responses
        ]

        return AdvancedChatPausedBlockingResponse(
            task_id=self._application_generate_entity.task_id,
            data=AdvancedChatPausedBlockingResponse.Data(
                id=self._message_id,
                mode=self._conversation_mode,
                conversation_id=self._conversation_id,
                message_id=self._message_id,
                workflow_run_id=human_input_responses[-1].workflow_run_id,
                answer=self._task_state.answer,
                metadata=self._message_end_to_stream_response().metadata,
                created_at=self._message_created_at,
                paused_nodes=paused_nodes,
                reasons=reasons,
                status=WorkflowExecutionStatus.PAUSED,
                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
                total_tokens=runtime_state.total_tokens,
                total_steps=runtime_state.node_run_steps,
            ),
        )

    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[ChatbotAppStreamResponse, Any, None]:

@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any, cast

from pydantic import JsonValue

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
@@ -14,9 +12,11 @@ from core.app.entities.task_entities import (
)


class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
                yield "ping"
                continue

            response_chunk: dict[str, JsonValue] = {
            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
                yield "ping"
                continue

            response_chunk: dict[str, JsonValue] = {
            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,

@@ -1,9 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union, cast

from pydantic import JsonValue
from typing import Any, Union

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -13,10 +11,8 @@ from graphon.model_runtime.errors.invoke import InvokeError
logger = logging.getLogger(__name__)


class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
    @classmethod
    def _cast_blocking_response(cls, response: AppBlockingResponse) -> TBlockingResponse:
        return cast(TBlockingResponse, response)
class AppGenerateResponseConverter(ABC):
    _blocking_response_type: type[AppBlockingResponse]

    @classmethod
    def convert(
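A minimal sketch (toy classes, not Dify's real entities) of the PEP 695 generic pattern in the class above: the type parameter pins each converter to its blocking-response subtype, and a cast helper narrows the type at the call site. Requires Python 3.12+.

from abc import ABC, abstractmethod
from typing import Any, cast


class Blocking:  # stand-in for AppBlockingResponse
    pass


class ChatBlocking(Blocking):  # stand-in for ChatbotAppBlockingResponse
    pass


class Converter[T: Blocking](ABC):
    @classmethod
    def _cast(cls, response: Blocking) -> T:
        # runtime no-op; tells the type checker which subtype this converter handles
        return cast(T, response)

    @classmethod
    @abstractmethod
    def convert_full(cls, blocking_response: T) -> dict[str, Any]: ...


class ChatConverter(Converter[ChatBlocking]):
    @classmethod
    def convert_full(cls, blocking_response: ChatBlocking) -> dict[str, Any]:
        return {"event": "message"}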
@@ -24,7 +20,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
        if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_full_response(cls._cast_blocking_response(response))
                return cls.convert_blocking_full_response(response)
            else:

                def _generate_full_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -33,7 +29,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
                return _generate_full_response()
        else:
            if isinstance(response, AppBlockingResponse):
                return cls.convert_blocking_simple_response(cls._cast_blocking_response(response))
                return cls.convert_blocking_simple_response(response)
            else:

                def _generate_simple_response() -> Generator[dict[str, Any] | str, Any, None]:
@@ -43,12 +39,12 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):

    @classmethod
    @abstractmethod
    def convert_blocking_full_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
    def convert_blocking_full_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def convert_blocking_simple_response(cls, blocking_response: TBlockingResponse) -> dict[str, Any]:
    def convert_blocking_simple_response(cls, blocking_response: AppBlockingResponse) -> dict[str, Any]:
        raise NotImplementedError

    @classmethod
@@ -110,13 +106,13 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
        return metadata

    @classmethod
    def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
    def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
        """
        Error to stream response.
        :param e: exception
        :return:
        """
        error_responses: dict[type[Exception], dict[str, JsonValue]] = {
        error_responses: dict[type[Exception], dict[str, Any]] = {
            ValueError: {"code": "invalid_param", "status": 400},
            ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
            QuotaExceededError: {
@@ -130,7 +126,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
        }

        # Determine the response based on the type of exception
        data: dict[str, JsonValue] | None = None
        data: dict[str, Any] | None = None
        for k, v in error_responses.items():
            if isinstance(e, k):
                data = v

@@ -1,8 +1,6 @@
from collections.abc import Generator
from typing import Any, cast

from pydantic import JsonValue

from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
    AppStreamResponse,
@@ -14,9 +12,11 @@ from core.app.entities.task_entities import (
)


class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
    _blocking_response_type = ChatbotAppBlockingResponse

    @classmethod
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
    def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
        """
        Convert blocking full response.
        :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
        return response

    @classmethod
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
    def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
        """
        Convert blocking simple response.
        :param blocking_response: blocking response
@@ -70,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
                yield "ping"
                continue

            response_chunk: dict[str, JsonValue] = {
            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,
@@ -101,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
                yield "ping"
                continue

            response_chunk: dict[str, JsonValue] = {
            response_chunk = {
                "event": sub_stream_response.event.value,
                "conversation_id": chunk.conversation_id,
                "message_id": chunk.message_id,

@@ -52,7 +52,6 @@ from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.human_input_forms import load_form_tokens_by_form_id
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
@@ -337,26 +336,7 @@ class WorkflowResponseConverter:
            except (TypeError, json.JSONDecodeError):
                definition_payload = {}
            display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
        form_token_by_form_id = load_form_tokens_by_form_id(
            human_input_form_ids,
            session=session,
            surface=(
                HumanInputSurface.SERVICE_API
                if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
                else None
            ),
        )

        # Reconnect paths must preserve the same pause-reason contract as live streams;
        # otherwise clients see schema drift after resume.
        pause_reasons = enrich_human_input_pause_reasons(
            pause_reasons,
            form_tokens_by_form_id=form_token_by_form_id,
            expiration_times_by_form_id={
                form_id: int(expiration_time.timestamp())
                for form_id, expiration_time in expiration_times_by_form_id.items()
            },
        )
        form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)

        responses: list[StreamResponse] = []
@ -1,8 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -14,15 +12,17 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = CompletionAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
response = {
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
@ -69,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
@@ -99,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
                 yield "ping"
                 continue

-            response_chunk: dict[str, JsonValue] = {
+            response_chunk = {
                 "event": sub_stream_response.event.value,
                 "message_id": chunk.message_id,
                 "created_at": chunk.created_at,

@@ -1,7 +1,6 @@
-from collections.abc import Callable, Generator, Iterable, Mapping
+from collections.abc import Callable, Generator, Mapping

 from core.app.apps.streaming_utils import stream_topic_events
-from core.app.entities.task_entities import StreamEvent
 from extensions.ext_redis import get_pubsub_broadcast_channel
 from libs.broadcast_channel.channel import Topic
 from models.model import AppMode

@@ -27,7 +26,6 @@ class MessageGenerator:
         idle_timeout=300,
         ping_interval: float = 10.0,
         on_subscribe: Callable[[], None] | None = None,
-        terminal_events: Iterable[str | StreamEvent] | None = None,
     ) -> Generator[Mapping | str, None, None]:
         topic = cls.get_response_topic(app_mode, workflow_run_id)
         return stream_topic_events(

@@ -35,5 +33,4 @@
             idle_timeout=idle_timeout,
             ping_interval=ping_interval,
             on_subscribe=on_subscribe,
-            terminal_events=terminal_events,
         )

@@ -13,9 +13,11 @@ from core.app.entities.task_entities import (
 )


-class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[WorkflowAppBlockingResponse]):
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
+    _blocking_response_type = WorkflowAppBlockingResponse
+
     @classmethod
-    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response

@@ -24,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter[Workflow
         return dict(blocking_response.model_dump())

     @classmethod
-    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, object]:
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]:  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -27,11 +27,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
 from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
-from core.app.entities.task_entities import (
-    WorkflowAppBlockingResponse,
-    WorkflowAppPausedBlockingResponse,
-    WorkflowAppStreamResponse,
-)
+from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.datasource.entities.datasource_entities import (
     DatasourceProviderType,
     OnlineDriveBrowseFilesRequest,

@@ -631,11 +627,7 @@ class PipelineGenerator(BaseAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> (
-        WorkflowAppBlockingResponse
-        | WorkflowAppPausedBlockingResponse
-        | Generator[WorkflowAppStreamResponse, None, None]
-    ):
+    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
         """
         Handle response.
         :param application_generate_entity: application generate entity

@@ -59,7 +59,7 @@ def stream_topic_events(


 def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
-    if terminal_events is None:
+    if not terminal_events:
         return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
     values: set[str] = set()
     for item in terminal_events:
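A note on the `_normalize_terminal_events` hunk above: loosening the guard from `is None` to `not terminal_events` means an explicitly empty iterable now also falls back to the default terminal set instead of producing an empty one. A minimal standalone sketch of the behavioral difference (the `StreamEvent` enum here is a stand-in, not the project's class):

    from enum import StrEnum


    class StreamEvent(StrEnum):  # stand-in for core.app.entities.task_entities.StreamEvent
        WORKFLOW_FINISHED = "workflow_finished"
        WORKFLOW_PAUSED = "workflow_paused"


    DEFAULTS = {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}


    def normalize_old(terminal_events):
        if terminal_events is None:  # old guard: only None falls back to the defaults
            return DEFAULTS
        return {e.value if isinstance(e, StreamEvent) else e for e in terminal_events}


    def normalize_new(terminal_events):
        if not terminal_events:  # new guard: None and empty iterables both fall back
            return DEFAULTS
        return {e.value if isinstance(e, StreamEvent) else e for e in terminal_events}


    assert normalize_old([]) == set()     # empty input used to mean "nothing is terminal"
    assert normalize_new([]) == DEFAULTS  # empty input now means "use the defaults"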
@@ -25,11 +25,7 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
 from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
-from core.app.entities.task_entities import (
-    WorkflowAppBlockingResponse,
-    WorkflowAppPausedBlockingResponse,
-    WorkflowAppStreamResponse,
-)
+from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
 from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
 from core.db.session_factory import session_factory
 from core.helper.trace_id_helper import extract_external_trace_id_from_args

@@ -616,11 +612,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
         user: Account | EndUser,
         draft_var_saver_factory: DraftVariableSaverFactory,
         stream: bool = False,
-    ) -> (
-        WorkflowAppBlockingResponse
-        | WorkflowAppPausedBlockingResponse
-        | Generator[WorkflowAppStreamResponse, None, None]
-    ):
+    ) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
         """
         Handle response.
         :param application_generate_entity: application generate entity

@@ -9,29 +9,24 @@ from core.app.entities.task_entities import (
     NodeStartStreamResponse,
     PingStreamResponse,
     WorkflowAppBlockingResponse,
-    WorkflowAppPausedBlockingResponse,
     WorkflowAppStreamResponse,
 )


-class WorkflowAppGenerateResponseConverter(
-    AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
-):
+class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
+    _blocking_response_type = WorkflowAppBlockingResponse
+
     @classmethod
-    def convert_blocking_full_response(
-        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
-    ) -> dict[str, Any]:
+    def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
         """
         Convert blocking full response.
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.model_dump())
+        return blocking_response.model_dump()

     @classmethod
-    def convert_blocking_simple_response(
-        cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
-    ) -> dict[str, Any]:
+    def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse):  # type: ignore[override]
         """
         Convert blocking simple response.
         :param blocking_response: blocking response

@@ -42,15 +42,12 @@ from core.app.entities.queue_entities import (
 )
 from core.app.entities.task_entities import (
     ErrorStreamResponse,
-    HumanInputRequiredPauseReasonPayload,
-    HumanInputRequiredResponse,
     MessageAudioEndStreamResponse,
     MessageAudioStreamResponse,
     PingStreamResponse,
     StreamResponse,
     TextChunkStreamResponse,
     WorkflowAppBlockingResponse,
-    WorkflowAppPausedBlockingResponse,
     WorkflowAppStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowPauseStreamResponse,

@@ -121,11 +118,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         )
         self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state

-    def process(
-        self,
-    ) -> Union[
-        WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]
-    ]:
+    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
         """
         Process generate task pipeline.
         :return:

@@ -136,24 +129,19 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         else:
             return self._to_blocking_response(generator)

-    def _to_blocking_response(
-        self, generator: Generator[StreamResponse, None, None]
-    ) -> Union[WorkflowAppBlockingResponse, WorkflowAppPausedBlockingResponse]:
+    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
         """
         To blocking response.
         :return:
         """
-        human_input_responses: list[HumanInputRequiredResponse] = []
         for stream_response in generator:
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
-            elif isinstance(stream_response, HumanInputRequiredResponse):
-                human_input_responses.append(stream_response)
             elif isinstance(stream_response, WorkflowPauseStreamResponse):
-                return WorkflowAppPausedBlockingResponse(
+                response = WorkflowAppBlockingResponse(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run_id=stream_response.data.workflow_run_id,
-                    data=WorkflowAppPausedBlockingResponse.Data(
+                    data=WorkflowAppBlockingResponse.Data(
                         id=stream_response.data.workflow_run_id,
                         workflow_id=self._workflow.id,
                         status=stream_response.data.status,

@@ -164,13 +152,12 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                         total_steps=stream_response.data.total_steps,
                         created_at=stream_response.data.created_at,
                         finished_at=None,
-                        paused_nodes=stream_response.data.paused_nodes,
-                        reasons=stream_response.data.reasons,
                     ),
                 )
+
+                return response
             elif isinstance(stream_response, WorkflowFinishStreamResponse):
-                return WorkflowAppBlockingResponse(
+                response = WorkflowAppBlockingResponse(
                     task_id=self._application_generate_entity.task_id,
                     workflow_run_id=stream_response.data.id,
                     data=WorkflowAppBlockingResponse.Data(

@@ -187,44 +174,12 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
                     ),
                 )
+
+                return response
             else:
                 continue

-        if human_input_responses:
-            return self._build_paused_blocking_response_from_human_input(human_input_responses)
-
         raise ValueError("queue listening stopped unexpectedly.")

-    def _build_paused_blocking_response_from_human_input(
-        self, human_input_responses: list[HumanInputRequiredResponse]
-    ) -> WorkflowAppPausedBlockingResponse:
-        runtime_state = self._resolve_graph_runtime_state()
-        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
-        created_at = int(runtime_state.start_at)
-        reasons = [
-            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
-            for response in human_input_responses
-        ]
-
-        return WorkflowAppPausedBlockingResponse(
-            task_id=self._application_generate_entity.task_id,
-            workflow_run_id=human_input_responses[-1].workflow_run_id,
-            data=WorkflowAppPausedBlockingResponse.Data(
-                id=human_input_responses[-1].workflow_run_id,
-                workflow_id=self._workflow.id,
-                status=WorkflowExecutionStatus.PAUSED,
-                outputs={},
-                error=None,
-                elapsed_time=time.perf_counter() - self._base_task_pipeline.start_at,
-                total_tokens=runtime_state.total_tokens,
-                total_steps=runtime_state.node_run_steps,
-                created_at=created_at,
-                finished_at=None,
-                paused_nodes=paused_nodes,
-                reasons=reasons,
-            ),
-        )

     def _to_stream_response(
         self, generator: Generator[StreamResponse, None, None]
     ) -> Generator[WorkflowAppStreamResponse, None, None]:

@@ -730,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
         match invoke_from:
             case InvokeFrom.SERVICE_API:
                 created_from = WorkflowAppLogCreatedFrom.SERVICE_API
+            case InvokeFrom.OPENAPI:
+                created_from = WorkflowAppLogCreatedFrom.OPENAPI
             case InvokeFrom.EXPLORE:
                 created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
             case InvokeFrom.WEB_APP:

@@ -24,6 +24,7 @@ class UserFrom(StrEnum):

 class InvokeFrom(StrEnum):
     SERVICE_API = "service-api"
+    OPENAPI = "openapi"
     WEB_APP = "web-app"
     TRIGGER = "trigger"
     EXPLORE = "explore"
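The two hunks above introduce the `OPENAPI` member and give it its own `match` arm so openapi-invoked runs get their own log source. A small hedged sketch of how a `StrEnum` member and its `match` arm pair up (both enums below are stand-ins; the `CreatedFrom` shape is assumed, not confirmed by this diff):

    from enum import StrEnum


    class InvokeFrom(StrEnum):  # stand-in mirroring the enum above
        SERVICE_API = "service-api"
        OPENAPI = "openapi"
        WEB_APP = "web-app"


    class CreatedFrom(StrEnum):  # stand-in for WorkflowAppLogCreatedFrom (assumed shape)
        SERVICE_API = "service-api"
        OPENAPI = "openapi"
        WEB_APP = "web-app"


    def created_from(invoke_from: InvokeFrom) -> CreatedFrom:
        # Each new enum member needs a matching arm, otherwise requests carrying
        # "openapi" would fall through to the default mapping.
        match invoke_from:
            case InvokeFrom.SERVICE_API:
                return CreatedFrom.SERVICE_API
            case InvokeFrom.OPENAPI:
                return CreatedFrom.OPENAPI
            case _:
                return CreatedFrom.WEB_APP


    assert InvokeFrom("openapi") is InvokeFrom.OPENAPI  # StrEnum round-trips its value
    assert created_from(InvokeFrom.OPENAPI) is CreatedFrom.OPENAPI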
@@ -1,13 +1,12 @@
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import Any, Literal
+from typing import Any

-from pydantic import BaseModel, ConfigDict, Field, JsonValue
+from pydantic import BaseModel, ConfigDict, Field

 from core.app.entities.agent_strategy import AgentStrategyInfo
 from core.rag.entities import RetrievalSourceMetadata
 from graphon.entities import WorkflowStartReason
-from graphon.entities.pause_reason import PauseReasonType
 from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from graphon.nodes.human_input.entities import FormInput, UserAction

@@ -296,40 +295,6 @@ class HumanInputRequiredResponse(StreamResponse):
     data: Data


-class HumanInputRequiredPauseReasonPayload(BaseModel):
-    """
-    Public pause-reason payload used by blocking responses when only
-    ``human_input_required`` events are available.
-    """
-
-    TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
-    form_id: str
-    node_id: str
-    node_title: str
-    form_content: str
-    inputs: Sequence[FormInput] = Field(default_factory=list)
-    actions: Sequence[UserAction] = Field(default_factory=list)
-    display_in_ui: bool = False
-    form_token: str | None = None
-    resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
-    expiration_time: int
-
-    @classmethod
-    def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
-        return cls(
-            form_id=data.form_id,
-            node_id=data.node_id,
-            node_title=data.node_title,
-            form_content=data.form_content,
-            inputs=data.inputs,
-            actions=data.actions,
-            display_in_ui=data.display_in_ui,
-            form_token=data.form_token,
-            resolved_default_values=data.resolved_default_values,
-            expiration_time=data.expiration_time,
-        )
-
-
 class HumanInputFormFilledResponse(StreamResponse):
     class Data(BaseModel):
         """

@@ -390,7 +355,7 @@ class NodeStartStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

-    def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
+    def to_ignore_detail_dict(self):
         return {
             "event": self.event.value,
             "task_id": self.task_id,

@@ -447,7 +412,7 @@ class NodeFinishStreamResponse(StreamResponse):
     workflow_run_id: str
     data: Data

-    def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
+    def to_ignore_detail_dict(self):
         return {
             "event": self.event.value,
             "task_id": self.task_id,

@@ -809,34 +774,6 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
     data: Data


-class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
-    """
-    ChatbotAppPausedBlockingResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        id: str
-        mode: str
-        conversation_id: str
-        message_id: str
-        workflow_run_id: str
-        answer: str
-        metadata: Mapping[str, object] = Field(default_factory=dict)
-        created_at: int
-        paused_nodes: Sequence[str] = Field(default_factory=list)
-        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
-        status: WorkflowExecutionStatus
-        elapsed_time: float
-        total_tokens: int
-        total_steps: int
-
-    data: Data
-
-
 class CompletionAppBlockingResponse(AppBlockingResponse):
     """
     CompletionAppBlockingResponse entity

@@ -882,33 +819,6 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
     data: Data


-class WorkflowAppPausedBlockingResponse(AppBlockingResponse):
-    """
-    WorkflowAppPausedBlockingResponse entity
-    """
-
-    class Data(BaseModel):
-        """
-        Data entity
-        """
-
-        id: str
-        workflow_id: str
-        status: WorkflowExecutionStatus
-        outputs: Mapping[str, Any] | None = None
-        error: str | None = None
-        elapsed_time: float
-        total_tokens: int
-        total_steps: int
-        created_at: int
-        finished_at: int | None
-        paused_nodes: Sequence[str] = Field(default_factory=list)
-        reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
-
-    workflow_run_id: str
-    data: Data
-
-
 class AgentLogStreamResponse(StreamResponse):
     """
     AgentLogStreamResponse entity

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from collections.abc import Generator  # Changed from Iterator
+from collections.abc import Iterator
 from contextlib import contextmanager
 from contextvars import ContextVar
 from dataclasses import dataclass

@@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:


 @contextmanager
-def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:  # Changed from Iterator[None]
+def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
     token = _current_file_access_scope.set(scope)
     try:
         yield
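Both annotations reverted above are valid for a `@contextmanager` function: `Iterator[None]` is the conventional shorthand when the generator yields once and never uses `send()` or a return value, which `Generator[None, None, None]` spells out explicitly. A minimal standalone sketch:

    from collections.abc import Iterator
    from contextlib import contextmanager

    _stack: list[str] = []


    @contextmanager
    def bind(name: str) -> Iterator[None]:
        # Iterator[None] is equivalent here to Generator[None, None, None],
        # since the body only yields once and ignores sent/returned values.
        _stack.append(name)
        try:
            yield
        finally:
            _stack.pop()


    with bind("scope-a"):
        assert _stack == ["scope-a"]
    assert _stack == []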
@@ -1,6 +1,5 @@
 from __future__ import annotations

-from copy import deepcopy
 from typing import Any

 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity

@@ -15,21 +14,8 @@ from graphon.nodes.llm.protocols import CredentialsProvider


 class DifyCredentialsProvider:
-    """Resolves and returns LLM credentials for a given provider and model.
-
-    Fetched credentials are stored in :attr:`credentials_cache` and reused for
-    subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
-    Because of that cache, a single instance can return stale credentials after
-    the tenant or provider configuration changes (e.g. API key rotation).
-
-    Do **not** keep one instance for the lifetime of a process or across
-    unrelated invocations. Create a new provider per request, workflow run, or
-    other bounded scope where up-to-date credentials matter.
-    """
-
     tenant_id: str
     provider_manager: ProviderManager
-    credentials_cache: dict[tuple[str, str], dict[str, Any]]

     def __init__(
         self,

@@ -44,12 +30,8 @@ class DifyCredentialsProvider:
             user_id=run_context.user_id,
         )
         self.provider_manager = provider_manager
-        self.credentials_cache = {}

     def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
-        if (provider_name, model_name) in self.credentials_cache:
-            return deepcopy(self.credentials_cache[(provider_name, model_name)])
-
         provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
         provider_configuration = provider_configurations.get(provider_name)
         if not provider_configuration:

@@ -64,7 +46,6 @@
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

-        self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
         return credentials


@@ -84,8 +65,7 @@ class DifyModelFactory:
                 provider_manager=create_plugin_provider_manager(
                     tenant_id=run_context.tenant_id,
                     user_id=run_context.user_id,
                 ),
-                enable_credentials_cache=True,
             )
         )
         self.model_manager = model_manager

@@ -104,7 +84,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
         tenant_id=run_context.tenant_id,
         user_id=run_context.user_id,
     )
-    model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
+    model_manager = ModelManager(provider_manager=provider_manager)

    return (
        DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

@@ -1,41 +0,0 @@
-"""
-Helper module for Creators Platform integration.
-
-Provides functionality to upload DSL files to the Creators Platform
-and generate redirect URLs with OAuth authorization codes.
-"""
-
-import logging
-from urllib.parse import urlencode
-
-import httpx
-from yarl import URL
-
-from configs import dify_config
-
-logger = logging.getLogger(__name__)
-
-creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
-
-
-def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
-    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
-    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
-    response.raise_for_status()
-    data = response.json()
-    claim_code = data.get("data", {}).get("claim_code")
-    if not claim_code:
-        raise ValueError("Creators Platform did not return a valid claim_code")
-    return claim_code
-
-
-def get_redirect_url(user_account_id: str, claim_code: str) -> str:
-    base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
-    params: dict[str, str] = {"dsl_claim_code": claim_code}
-    client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
-    if client_id:
-        from services.oauth_server import OAuthServerService
-
-        oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
-        params["oauth_code"] = oauth_code
-    return f"{base_url}?{urlencode(params)}"

@@ -2,7 +2,7 @@ import json
 import logging
 import re
 from collections.abc import Sequence
-from typing import Any, NotRequired, Protocol, TypedDict, cast
+from typing import Any, Protocol, TypedDict, cast

 import json_repair
 from sqlalchemy import select

@@ -18,6 +18,8 @@ from core.llm_generator.prompts import (
     LLM_MODIFY_CODE_SYSTEM,
     LLM_MODIFY_PROMPT_SYSTEM,
     PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+    SUGGESTED_QUESTIONS_MAX_TOKENS,
+    SUGGESTED_QUESTIONS_TEMPERATURE,
     SYSTEM_STRUCTURED_OUTPUT_GENERATE,
     WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
 )

@@ -39,36 +41,6 @@ from models.workflow import Workflow
 logger = logging.getLogger(__name__)


-class SuggestedQuestionsModelConfig(TypedDict):
-    provider: str
-    name: str
-    completion_params: NotRequired[dict[str, object]]
-
-
-def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]:
-    """
-    Normalize raw completion params into invocation parameters and stop sequences.
-
-    This mirrors the app-model access path by separating ``stop`` from provider
-    parameters before invocation, then drops non-positive token limits because
-    some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their
-    provider-specific output-token field.
-    """
-    normalized_parameters = dict(completion_params)
-    stop_value = normalized_parameters.pop("stop", [])
-    if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value):
-        stop = stop_value
-    else:
-        stop = []
-
-    for token_limit_key in ("max_tokens", "max_output_tokens"):
-        token_limit = normalized_parameters.get(token_limit_key)
-        if isinstance(token_limit, int | float) and token_limit <= 0:
-            normalized_parameters.pop(token_limit_key, None)
-
-    return normalized_parameters, stop
-
-
 class WorkflowServiceInterface(Protocol):
     def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
         pass
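For readers untangling the revert, the deleted `_normalize_completion_params` helper above split `stop` out of the raw completion params and dropped non-positive token limits before invocation. A standalone sketch of that same behavior (mirroring the deleted helper; nothing by this name remains after the change):

    def normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]:
        # Separate "stop" from provider parameters; keep it only if it is a list of strings.
        params = dict(completion_params)
        stop_value = params.pop("stop", [])
        stop = stop_value if isinstance(stop_value, list) and all(isinstance(s, str) for s in stop_value) else []

        # Drop non-positive token limits, which some providers reject outright.
        for key in ("max_tokens", "max_output_tokens"):
            limit = params.get(key)
            if isinstance(limit, (int, float)) and limit <= 0:
                params.pop(key, None)
        return params, stop


    params, stop = normalize_completion_params({"temperature": 0.2, "max_tokens": 0, "stop": ["\n\n"]})
    assert params == {"temperature": 0.2}  # "max_tokens": 0 was dropped
    assert stop == ["\n\n"]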
@@ -151,15 +123,8 @@ class LLMGenerator:
         return name

     @classmethod
-    def generate_suggested_questions_after_answer(
-        cls,
-        tenant_id: str,
-        histories: str,
-        *,
-        instruction_prompt: str | None = None,
-        model_config: object | None = None,
-    ) -> Sequence[str]:
-        output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt)
+    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
+        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()

         prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")

@@ -168,36 +133,10 @@

         try:
             model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
-            configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {}
-            provider = configured_model.get("provider")
-            model_name = configured_model.get("name")
-            use_configured_model = False
-
-            if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
-                try:
-                    model_instance = model_manager.get_model_instance(
-                        tenant_id=tenant_id,
-                        model_type=ModelType.LLM,
-                        provider=provider,
-                        model=model_name,
-                    )
-                    use_configured_model = True
-                except Exception:
-                    logger.warning(
-                        "Failed to use configured suggested-questions model %s/%s, fallback to default model",
-                        provider,
-                        model_name,
-                        exc_info=True,
-                    )
-                    model_instance = model_manager.get_default_model_instance(
-                        tenant_id=tenant_id,
-                        model_type=ModelType.LLM,
-                    )
-            else:
-                model_instance = model_manager.get_default_model_instance(
-                    tenant_id=tenant_id,
-                    model_type=ModelType.LLM,
-                )
+            model_instance = model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM,
+            )
         except InvokeAuthorizationError:
             return []

@@ -206,29 +145,19 @@
         questions: Sequence[str] = []

         try:
-            configured_completion_params = configured_model.get("completion_params")
-            if use_configured_model and isinstance(configured_completion_params, dict):
-                model_parameters, stop = _normalize_completion_params(configured_completion_params)
-            elif use_configured_model:
-                model_parameters = {}
-                stop = []
-            else:
-                # Default-model generation keeps the built-in suggested-questions tuning.
-                model_parameters = {
-                    "max_tokens": 2560,
-                    "temperature": 0.0,
-                }
-                stop = []
-
             response: LLMResult = model_instance.invoke_llm(
                 prompt_messages=list(prompt_messages),
-                model_parameters=model_parameters,
-                stop=stop,
+                model_parameters={
+                    "max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
+                    "temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
+                },
                 stream=False,
             )

             text_content = response.message.get_text_content()
             questions = output_parser.parse(text_content) if text_content else []
         except InvokeError:
             questions = []
         except Exception:
             logger.exception("Failed to generate suggested questions after answer")
             questions = []

@@ -3,28 +3,17 @@ import logging
 import re
 from collections.abc import Sequence

-from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
+from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

 logger = logging.getLogger(__name__)


 class SuggestedQuestionsAfterAnswerOutputParser:
-    def __init__(self, instruction_prompt: str | None = None) -> None:
-        self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
-
-    @staticmethod
-    def _build_instruction_prompt(instruction_prompt: str | None) -> str:
-        if not instruction_prompt or not instruction_prompt.strip():
-            return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
-
-        return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
-
     def get_format_instructions(self) -> str:
-        return self._instruction_prompt
+        return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

     def parse(self, text: str) -> Sequence[str]:
-        stripped_text = text.strip()
-        action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL)
+        action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
         questions: list[str] = []
         if action_match is not None:
             try:

@@ -34,6 +23,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
             else:
                 if isinstance(json_obj, list):
                     questions = [question for question in json_obj if isinstance(question, str)]
-        elif stripped_text:
-            logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200])
         return questions
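The parser kept above recovers the questions array by regex-matching the first bracketed block and keeping only string entries. A standalone sketch of that extraction step, using plain `json.loads` (the real parser's decoding sits behind a `try` block whose body this diff does not show):

    import json
    import re


    def parse_questions(text: str) -> list[str]:
        # Find the first bracketed block, non-greedily, across newlines.
        match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
        if match is None:
            return []
        try:
            payload = json.loads(match.group(0))
        except json.JSONDecodeError:
            return []
        # Keep only string entries, mirroring the isinstance filter above.
        return [q for q in payload if isinstance(q, str)] if isinstance(payload, list) else []


    sample = 'Sure! ["What is Dify?", "How do workflows pause?", 42]'
    assert parse_questions(sample) == ["What is Dify?", "How do workflows pause?"]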
@@ -1,4 +1,5 @@
 # Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
+import os

 CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.

@@ -95,8 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
 )


-# Default prompt and model parameters for suggested questions.
-DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
+# Default prompt for suggested questions (can be overridden by environment variable)
+_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
     "Please help me predict the three most likely questions that human would ask, "
     "and keep each question under 20 characters.\n"
     "MAKE SURE your output is the SAME language as the Assistant's latest response. "

@@ -104,6 +105,15 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
     '["question1","question2","question3"]\n'
 )

+# Environment variable override for suggested questions prompt
+SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
+    "SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
+)
+
+# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
+SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
+SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))
+
 GENERATOR_QA_PROMPT = (
     "<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
     " in the long text. Please think step by step."
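Because the new constants above are read with `os.getenv` at module import time, an override must be present in the process environment before the module loads; changing it afterwards has no effect. A small sketch of that snapshot-at-import behavior:

    import os

    # Values are snapshotted when the module-level code runs, not on each use.
    os.environ["SUGGESTED_QUESTIONS_MAX_TOKENS"] = "512"

    MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
    TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))

    assert MAX_TOKENS == 512   # override seen because it was set before the read
    assert TEMPERATURE == 0.0  # unset variable falls back to the default string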
@@ -1,6 +1,5 @@
 import logging
 from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
-from copy import deepcopy
 from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

 from configs import dify_config

@@ -37,13 +36,11 @@ class ModelInstance:
     Model instance class.
     """

-    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
+    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
         self.provider_model_bundle = provider_model_bundle
         self.model_name = model
         self.provider = provider_model_bundle.configuration.provider.provider
-        if credentials is None:
-            credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
-        self.credentials = credentials
+        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
         # Runtime LLM invocation fields.
         self.parameters: Mapping[str, Any] = {}
         self.stop: Sequence[str] = ()

@@ -437,30 +434,8 @@


 class ModelManager:
-    """Resolves :class:`ModelInstance` objects for a tenant and provider.
-
-    When ``enable_credentials_cache`` is ``True``, resolved credentials for each
-    ``(tenant_id, provider, model_type, model)`` are stored in
-    ``_credentials_cache`` and reused. That can return **stale** credentials after
-    API keys or provider settings change, so a manager constructed with
-    ``enable_credentials_cache=True`` should not be kept for the lifetime of a
-    process or shared across unrelated work. Prefer a new manager per request,
-    workflow run, or similar bounded scope.
-
-    The default is ``enable_credentials_cache=False``; in that mode the internal
-    credential cache is not populated, and each ``get_model_instance`` call
-    loads credentials from the current provider configuration.
-    """
-
-    def __init__(
-        self,
-        provider_manager: ProviderManager,
-        *,
-        enable_credentials_cache: bool = False,
-    ) -> None:
+    def __init__(self, provider_manager: ProviderManager):
         self._provider_manager = provider_manager
-        self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
-        self._enable_credentials_cache = enable_credentials_cache

     @classmethod
     def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":

@@ -488,19 +463,8 @@ class ModelManager:
             tenant_id=tenant_id, provider=provider, model_type=model_type
         )

-        cred_cache_key = (tenant_id, provider, model_type.value, model)
-
-        if cred_cache_key in self._credentials_cache:
-            return ModelInstance(
-                provider_model_bundle,
-                model,
-                deepcopy(self._credentials_cache[cred_cache_key]),
-            )
-
-        ret = ModelInstance(provider_model_bundle, model)
-        if self._enable_credentials_cache:
-            self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
-        return ret
+        model_instance = ModelInstance(provider_model_bundle, model)
+        return model_instance

     def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
         """
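The docstring deleted above records why the credentials cache is being removed along with it: a cache keyed on `(tenant_id, provider, model_type, model)` keeps serving old secrets after a key rotation unless the manager is short-lived. A generic standalone sketch of that staleness hazard (toy names, not the project's classes):

    class CachedCredentialResolver:
        """Caches resolved credentials per key; never invalidates on its own."""

        def __init__(self, backing_store: dict[str, str]):
            self._store = backing_store
            self._cache: dict[str, str] = {}

        def fetch(self, key: str) -> str:
            if key not in self._cache:
                self._cache[key] = self._store[key]
            return self._cache[key]


    store = {"openai": "sk-old"}
    resolver = CachedCredentialResolver(store)   # imagine this lives for the whole process
    assert resolver.fetch("openai") == "sk-old"

    store["openai"] = "sk-rotated"               # admin rotates the API key
    assert resolver.fetch("openai") == "sk-old"  # stale: the long-lived cache wins

    fresh = CachedCredentialResolver(store)      # request-scoped instance sees the rotation
    assert fresh.fetch("openai") == "sk-rotated"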
@@ -70,32 +70,12 @@ class ProviderManager:
     Request-bound managers may carry caller identity in that runtime, and the
     resulting ``ProviderConfiguration`` objects must reuse it for downstream
     model-type and schema lookups.
-
-    Configuration assembly is cached per manager instance so call chains that
-    share one request-scoped manager can reuse the same provider graph instead
-    of rebuilding it for every lookup. Call ``clear_configurations_cache()``
-    when a long-lived manager needs to observe writes performed within the same
-    instance scope.
     """

     decoding_rsa_key: Any | None
     decoding_cipher_rsa: Any | None
     _model_runtime: ModelRuntime
-    _configurations_cache: dict[str, ProviderConfigurations]

     def __init__(self, model_runtime: ModelRuntime):
         self.decoding_rsa_key = None
         self.decoding_cipher_rsa = None
         self._model_runtime = model_runtime
-        self._configurations_cache = {}
-
-    def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
-        """Drop assembled provider configurations cached on this manager instance."""
-        if tenant_id is None:
-            self._configurations_cache.clear()
-            return
-
-        self._configurations_cache.pop(tenant_id, None)

     def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
         """

@@ -134,10 +114,6 @@ class ProviderManager:
         :param tenant_id:
         :return:
         """
-        cached_configurations = self._configurations_cache.get(tenant_id)
-        if cached_configurations is not None:
-            return cached_configurations
-
         # Get all provider records of the workspace
         provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)

@@ -297,8 +273,6 @@

         provider_configurations[str(provider_id_entity)] = provider_configuration

-        self._configurations_cache[tenant_id] = provider_configurations
-
         # Return the encapsulated object
         return provider_configurations


@@ -139,10 +139,8 @@ class Jieba(BaseKeyword):
             "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
         }
         dataset_keyword_table = self.dataset.dataset_keyword_table
-        keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
+        keyword_data_source_type = dataset_keyword_table.data_source_type
         if keyword_data_source_type == "database":
-            if dataset_keyword_table is None:
-                return
             dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
             db.session.commit()
         else:

@@ -156,8 +154,7 @@
         if dataset_keyword_table:
             keyword_table_dict = dataset_keyword_table.keyword_table_dict
             if keyword_table_dict:
-                data: Any = keyword_table_dict["__data__"]
-                return dict(data["table"])
+                return dict(keyword_table_dict["__data__"]["table"])
         else:
             keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
             dataset_keyword_table = DatasetKeywordTable(

@@ -1,5 +1,4 @@
 import re
-from collections.abc import Callable
 from operator import itemgetter
 from typing import cast

@@ -81,14 +80,12 @@ class JiebaKeywordTableHandler:

     def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
-        # Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
-        top_k = cast(int | None, kwargs.pop("topK", top_k))
-        if top_k is None:
-            top_k = 20
+        top_k = kwargs.pop("topK", top_k)
         cut = getattr(jieba, "cut", None)
         if self._lcut:
             tokens = self._lcut(sentence)
         elif callable(cut):
-            tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
+            tokens = list(cut(sentence))
         else:
             tokens = re.findall(r"\w+", sentence)


@@ -109,9 +106,9 @@
         """Extract keywords with JIEBA tfidf."""
         keywords = self._tfidf.extract_tags(
             sentence=text,
-            topK=max_keywords_per_chunk or 10,
+            topK=max_keywords_per_chunk,
         )
-        # jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
+        # jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
         keywords = cast(list[str], keywords)

         return set(self._expand_tokens_with_subtokens(set(keywords)))

@@ -158,7 +158,7 @@ class RetrievalService:
             )

         if futures:
-            for _ in concurrent.futures.as_completed(futures, timeout=3600):
+            for future in concurrent.futures.as_completed(futures, timeout=3600):
                 if exceptions:
                     for f in futures:
                         f.cancel()

@@ -551,7 +551,6 @@
         child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

         for i in child_index_nodes:
-            assert i.index_node_id
             segment_ids.append(i.segment_id)
             if i.segment_id in child_chunk_map:
                 child_chunk_map[i.segment_id].append(i)

@@ -39,58 +39,6 @@ class AbstractVectorFactory(ABC):
         return index_struct_dict


-class _LazyEmbeddings(Embeddings):
-    """Lazy proxy that defers materializing the real embedding model.
-
-    Constructing the real embeddings (via ``ModelManager.get_model_instance``)
-    transitively calls ``FeatureService.get_features`` → ``BillingService``
-    HTTP GETs (see ``provider_manager.py``). Cleanup paths
-    (``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
-    at all, so deferring this until an ``embed_*`` method is actually invoked
-    keeps cleanup tasks resilient to transient billing-API failures and avoids
-    leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
-    hiccups.
-
-    Existing callers that perform create / search operations are unaffected:
-    the first ``embed_*`` call materializes the underlying model and the
-    behavior is identical from that point on.
-    """
-
-    def __init__(self, dataset: Dataset):
-        self._dataset = dataset
-        self._real: Embeddings | None = None
-
-    def _ensure(self) -> Embeddings:
-        if self._real is None:
-            model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
-            embedding_model = model_manager.get_model_instance(
-                tenant_id=self._dataset.tenant_id,
-                provider=self._dataset.embedding_model_provider,
-                model_type=ModelType.TEXT_EMBEDDING,
-                model=self._dataset.embedding_model,
-            )
-            self._real = CacheEmbedding(embedding_model)
-        return self._real
-
-    def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        return self._ensure().embed_documents(texts)
-
-    def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
-        return self._ensure().embed_multimodal_documents(multimodel_documents)
-
-    def embed_query(self, text: str) -> list[float]:
-        return self._ensure().embed_query(text)
-
-    def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
-        return self._ensure().embed_multimodal_query(multimodel_document)
-
-    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
-        return await self._ensure().aembed_documents(texts)
-
-    async def aembed_query(self, text: str) -> list[float]:
-        return await self._ensure().aembed_query(text)
-
-
 class Vector:
     def __init__(self, dataset: Dataset, attributes: list | None = None):
         if attributes is None:
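The `_LazyEmbeddings` docstring removed above records the design being reverted: defer building the real embedding model so cleanup paths never trigger billing-API calls. A generic standalone sketch of the same lazy-proxy pattern (toy types, not the project's `Embeddings` interface):

    from collections.abc import Callable


    class LazyProxy:
        """Builds the expensive target on first use, then delegates to it."""

        def __init__(self, factory: Callable[[], object]):
            self._factory = factory
            self._real: object | None = None

        def _ensure(self) -> object:
            if self._real is None:
                self._real = self._factory()  # expensive side effects happen here, once
            return self._real


    calls = []


    def expensive_factory() -> object:
        calls.append("built")  # e.g. a network call to a billing service
        return object()


    proxy = LazyProxy(expensive_factory)
    assert calls == []           # constructing the proxy itself is free
    proxy._ensure()
    proxy._ensure()
    assert calls == ["built"]    # the target is materialized exactly once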
@@ -112,11 +60,7 @@ class Vector:
             "original_chunk_id",
         ]
         self._dataset = dataset
-        # Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
-        # never transitively trigger billing API calls during ``Vector(dataset)``
-        # construction. The real embedding model is materialized only when an
-        # ``embed_*`` method is actually invoked (i.e. create / search paths).
-        self._embeddings: Embeddings = _LazyEmbeddings(dataset)
+        self._embeddings = self._get_embeddings()
         self._attributes = attributes
         self._vector_processor = self._init_vector()

@@ -11,7 +11,6 @@ from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
 from graphon.model_runtime.entities.model_entities import ModelType
 from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
-from models.enums import SegmentType


 class DatasetDocumentStore:

@@ -128,7 +127,6 @@ class DatasetDocumentStore:
         if save_child:
             if doc.children:
                 for position, child in enumerate(doc.children, start=1):
-                    assert self._document_id
                     child_segment = ChildChunk(
                         tenant_id=self._dataset.tenant_id,
                         dataset_id=self._dataset.id,

@@ -139,7 +137,7 @@ class DatasetDocumentStore:
                         index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
-                        type=SegmentType.AUTOMATIC,
+                        type="automatic",
                         created_by=self._user_id,
                     )
                     db.session.add(child_segment)

@@ -165,7 +163,6 @@ class DatasetDocumentStore:
                     )
                     # add new child chunks
                     for position, child in enumerate(doc.children, start=1):
-                        assert self._document_id
                         child_segment = ChildChunk(
                             tenant_id=self._dataset.tenant_id,
                             dataset_id=self._dataset.id,

@@ -176,7 +173,7 @@ class DatasetDocumentStore:
                         index_node_hash=child.metadata.get("doc_hash"),
                         content=child.page_content,
                         word_count=len(child.page_content),
-                        type=SegmentType.AUTOMATIC,
+                        type="automatic",
                         created_by=self._user_id,
                     )
                     db.session.add(child_segment)

@@ -94,7 +94,6 @@ class ExtractProcessor:
         cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
     ) -> list[Document]:
         if extract_setting.datasource_type == DatasourceType.FILE:
-            upload_file = extract_setting.upload_file
             with tempfile.TemporaryDirectory() as temp_dir:
                 upload_file = extract_setting.upload_file
                 if not file_path:

@@ -105,7 +104,6 @@
                 storage.download(upload_file.key, file_path)
                 input_file = Path(file_path)
                 file_extension = input_file.suffix.lower()
-                assert upload_file is not None, "upload_file is required"
                 etl_type = dify_config.ETL_TYPE
                 extractor: BaseExtractor | None = None
                 if etl_type == "Unstructured":
Some files were not shown because too many files have changed in this diff.