Mirror of https://github.com/langgenius/dify.git (synced 2026-05-11 04:37:17 +08:00)

Compare commits: chore/remo...p363-revie (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 82b1c5bc12 |  |
@@ -367,7 +367,7 @@ For each extraction:
┌────────────────────────────────────────┐
│ 1. Extract code │
│ 2. Run: pnpm lint:fix │
│ 3. Run: pnpm type-check │
│ 3. Run: pnpm type-check:tsgo │
│ 4. Run: pnpm test │
│ 5. Test functionality manually │
│ 6. PASS? → Next extraction │
@@ -1,6 +1,6 @@
---
name: frontend-query-mutation
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
---

# Frontend Query & Mutation

@@ -9,24 +9,22 @@ description: Guide for implementing Dify frontend query and mutation patterns wi

- Keep contract as the single source of truth in `web/contract/*`.
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
- Keep invalidation and mutation flow knowledge in the service layer.
- Keep abstractions minimal to preserve TypeScript inference.

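A minimal sketch of the contract-shaped call these principles point to, reusing the `consoleQuery.billing.invoices` operation that appears later in this diff; the import path and input fields are illustrative assumptions, not the exact project code.

```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias for web/service/client.ts

// Direct, contract-shaped call at the call site.
// TypeScript infers the result type from the contract, so no wrapper hook is needed.
const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
  input: { query: { page: 1 } },
}))
```
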
## Workflow

1. Identify the change surface.
   - Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
   - Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
   - Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
   - Read both references when a task spans contract shape and runtime behavior.
2. Implement the smallest abstraction that fits the task.
   - Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
   - Extract a small shared query helper only when multiple call sites share the same extra options.
   - Create or keep feature hooks only for real orchestration or shared domain behavior.
   - When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
   - Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
3. Preserve Dify conventions.
   - Keep contract inputs in `{ params, query?, body? }` shape.
   - Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
   - Bind invalidation in the service-layer mutation definition.
   - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.

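The conventions in item 3 are easier to see in code. A hedged sketch of the `{ params, query?, body? }` input shape together with the `mutate(...)` preference; the `apps.update` operation name, its fields, and the import path are assumptions for illustration only.

```typescript
import { useMutation } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias

const updateAppMutation = useMutation(consoleQuery.apps.update.mutationOptions())

// Contract inputs keep the { params, query?, body? } shape; mutate() is preferred
// over mutateAsync() unless the caller really needs the returned Promise.
updateAppMutation.mutate({
  params: { appId: 'app-id' },
  body: { name: 'Renamed app' },
})
```
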
## Files Commonly Touched

@@ -35,7 +33,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi

- `web/contract/marketplace.ts`
- `web/contract/router.ts`
- `web/service/client.ts`
- legacy `web/service/use-*.ts` files when migrating wrappers away
- `web/service/use-*.ts`
- component and hook call sites using `consoleQuery` or `marketplaceQuery`

## References

@@ -1,4 +1,4 @@
interface:
  display_name: "Frontend Query & Mutation"
  short_description: "Dify TanStack Query, oRPC, and default option patterns"
  default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
  short_description: "Dify TanStack Query and oRPC patterns"
  default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."

@@ -7,7 +7,6 @@
- Core workflow
- Query usage decision rule
- Mutation usage decision rule
- Thin hook decision rule
- Anti-patterns
- Contract rules
- Type export
@@ -56,13 +55,9 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({

1. Default to direct `*.queryOptions(...)` usage at the call site.
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
3. Create or keep feature hooks only for orchestration.
3. Create `web/service/use-{domain}.ts` only for orchestration.
   - Combine multiple queries or mutations.
   - Share domain-level derived state or invalidation helpers.
   - Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
4. Treat `web/service/use-{domain}.ts` as legacy.
   - Do not create new thin service wrappers for oRPC contracts.
   - When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.

```typescript
const invoicesBaseQueryOptions = () =>
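The diff cuts the helper above short. A minimal sketch of what such a small shared query helper might look like when several call sites need the same extra options; the `staleTime` value, input fields, and import path are assumptions.

```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias

// One place for the shared extra options, still contract-shaped and fully inferred.
const invoicesBaseQueryOptions = () =>
  consoleQuery.billing.invoices.queryOptions({
    input: { query: { page: 1 } },
    staleTime: 60_000, // assumed shared option
  })

// Call sites stay thin.
const invoiceQuery = useQuery(invoicesBaseQueryOptions())
```
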
@@ -79,37 +74,11 @@ const invoiceQuery = useQuery({
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.

```typescript
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
```

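A hedged sketch of point 2 above, keeping the network calls on the oRPC client while the custom flow lives in a plain `mutationFn`; the bulk-delete scenario, the `{ params }` input shape, and the import path are assumptions.

```typescript
import { useMutation } from '@tanstack/react-query'
import { consoleClient } from '@/service/client' // assumed path alias

// Heavily custom flow: the oRPC client still owns the requests,
// so no handwritten non-oRPC fetch logic is introduced.
const bulkDeleteTagsMutation = useMutation({
  mutationFn: async (tagIds: string[]) => {
    for (const tagId of tagIds)
      await consoleClient.tags.delete({ params: { tagId } })
    return tagIds
  },
})
```
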
## Thin Hook Decision Rule

Remove thin hooks when they only rename a single oRPC query or mutation helper.
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.

Use:

```typescript
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
```

Keep:

```typescript
const applyTagBindingsMutation = useApplyTagBindingsMutation()
```

`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.

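An illustrative sketch of the kind of orchestration described above; the delta computation, the `tags.bind`/`tags.unbind` contract paths, their input shapes, and the import path are assumptions rather than the actual implementation.

```typescript
import { useMutation } from '@tanstack/react-query'
import { consoleClient } from '@/service/client' // assumed path alias

// Feature-level workflow: compute bind/unbind deltas, then run both requests.
export const useApplyTagBindingsMutation = () =>
  useMutation({
    mutationFn: async (input: { targetId: string; current: string[]; next: string[] }) => {
      const toBind = input.next.filter(id => !input.current.includes(id))
      const toUnbind = input.current.filter(id => !input.next.includes(id))
      if (toBind.length)
        await consoleClient.tags.bind({ body: { targetId: input.targetId, tagIds: toBind } })
      if (toUnbind.length)
        await consoleClient.tags.unbind({ body: { targetId: input.targetId, tagIds: toUnbind } })
    },
  })
```
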
## Anti-Patterns

- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
- Do not create thin `use-*` passthrough hooks for a single endpoint.
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.

## Contract Rules

@@ -3,7 +3,6 @@
## Table of Contents

- Conditional queries
- oRPC default options
- Cache invalidation
- Key API guide
- `mutate` vs `mutateAsync`
@@ -36,50 +35,9 @@ function useBadAccessMode(appId: string | undefined) {
}
```

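The conditional-query example above is cut off by the hunk. A minimal sketch of one common way to gate a query with `enabled` instead of calling hooks conditionally; the `webapp.accessMode` contract path, the input shape, and the import path are assumptions.

```typescript
import { useQuery } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias

function useAccessMode(appId: string | undefined) {
  // The query only runs once appId exists; the hook itself is always called.
  return useQuery(consoleQuery.webapp.accessMode.queryOptions({
    input: { query: { appId: appId ?? '' } },
    enabled: !!appId,
  }))
}
```
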
## oRPC Default Options

Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.

Place defaults at the query utility creation point in `web/service/client.ts`:

```typescript
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
  path: ['console'],
  experimental_defaults: {
    tags: {
      create: {
        mutationOptions: {
          onSuccess: (tag, _variables, _result, context) => {
            context.client.setQueryData(
              consoleQuery.tags.list.queryKey({
                input: {
                  query: {
                    type: tag.type,
                  },
                },
              }),
              (oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
            )
          },
        },
      },
    },
  },
})
```

Rules:

- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.

## Cache Invalidation

Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
Bind invalidation in the service-layer mutation definition.
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.

Use:
@@ -91,7 +49,7 @@ Use:
Do not use deprecated `useInvalid` from `use-base.ts`.

```typescript
// Feature orchestration owns cache invalidation only when defaults are not enough.
// Service layer owns cache invalidation.
export const useUpdateAccessMode = () => {
  const queryClient = useQueryClient()

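The `useUpdateAccessMode` definition above is truncated by the hunk. A hedged sketch of the service-layer shape it describes, where invalidation lives in the mutation definition; the `webapp.updateAccessMode` path, the `.key()` helper usage, and the import path are assumptions.

```typescript
import { useMutation, useQueryClient } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias

export const useUpdateAccessMode = () => {
  const queryClient = useQueryClient()

  return useMutation(consoleQuery.webapp.updateAccessMode.mutationOptions({
    // The mutation definition decides what to invalidate; components only add UI feedback.
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: consoleQuery.webapp.accessMode.key() })
    },
  }))
}
```
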
@@ -166,7 +124,7 @@ When touching old code, migrate it toward these rules:

| Old pattern | New pattern |
|---|---|
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |

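A small sketch of the last migration row, assuming the `tags.delete` contract operation shown earlier in this diff; error handling moves into a callback instead of an unguarded `await`.

```typescript
import { useMutation } from '@tanstack/react-query'
import { consoleQuery } from '@/service/client' // assumed path alias

const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())

// Before: await deleteTagMutation.mutateAsync({ params: { tagId } })  // unhandled rejection on failure
// After: fire-and-observe with mutate; failures surface through the callback.
deleteTagMutation.mutate(
  { params: { tagId: 'tag-id' } },
  { onError: error => console.error(error) },
)
```
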
@@ -127,7 +127,7 @@ For the current file being tested:
- [ ] Run full directory test: `pnpm test path/to/directory/`
- [ ] Check coverage report: `pnpm test:coverage`
- [ ] Run `pnpm lint:fix` on all test files
- [ ] Run `pnpm type-check`
- [ ] Run `pnpm type-check:tsgo`

## Common Issues to Watch

3  .github/CODEOWNERS  vendored
@@ -6,9 +6,6 @@

* @crazywoola @laipz8200 @Yeuoly

# ESLint suppression file is maintained by autofix.ci pruning.
/eslint-suppressions.json

# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola

2
.github/actions/setup-web/action.yml
vendored
2
.github/actions/setup-web/action.yml
vendored
@ -4,7 +4,7 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
|
||||
1
.github/labeler.yml
vendored
1
.github/labeler.yml
vendored
@ -6,4 +6,5 @@ web:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
8
.github/workflows/api-tests.yml
vendored
8
.github/workflows/api-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
api-unit:
|
||||
name: API Unit Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
COVERAGE_FILE: coverage-unit
|
||||
defaults:
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
api-integration:
|
||||
name: API Integration Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
COVERAGE_FILE: coverage-integration
|
||||
STORAGE_TYPE: opendal
|
||||
@ -98,7 +98,7 @@ jobs:
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
|
||||
api-coverage:
|
||||
name: API Coverage
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- api-unit
|
||||
- api-integration
|
||||
|
||||
5
.github/workflows/autofix.yml
vendored
5
.github/workflows/autofix.yml
vendored
@ -13,7 +13,7 @@ permissions:
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
@ -43,6 +43,7 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
@ -113,7 +114,7 @@ jobs:
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Setup web environment
|
||||
if: github.event_name != 'merge_group'
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: ESLint autofix
|
||||
|
||||
46
.github/workflows/build-push.yml
vendored
46
.github/workflows/build-push.yml
vendored
@ -26,9 +26,6 @@ jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
@ -38,28 +35,28 @@ jobs:
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-24.04-arm
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-24.04-arm
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@ -73,8 +70,8 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
@ -84,15 +81,16 @@ jobs:
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=${{ matrix.service_name }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
|
||||
|
||||
- name: Export digest
|
||||
env:
|
||||
@ -110,33 +108,9 @@ jobs:
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
fork-build-validate:
|
||||
if: github.repository != 'langgenius/dify'
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "validate-api-amd64"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "validate-web-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Validate Docker image
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: linux/amd64
|
||||
|
||||
create-manifest:
|
||||
needs: build
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
26
.github/workflows/db-migration-test.yml
vendored
26
.github/workflows/db-migration-test.yml
vendored
@ -9,7 +9,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
db-migration-test-postgres:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -59,7 +59,7 @@ jobs:
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -110,28 +110,6 @@ jobs:
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
|
||||
# to return (container processes started); it does not wait on healthcheck
|
||||
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
|
||||
# wait the migration runs while InnoDB is still initialising and gets
|
||||
# killed with "Lost connection during query". Poll a real SELECT until it
|
||||
# succeeds.
|
||||
- name: Wait for MySQL to accept queries
|
||||
run: |
|
||||
set +e
|
||||
for i in $(seq 1 60); do
|
||||
if docker run --rm --network host mysql:8.0 \
|
||||
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
|
||||
-e 'SELECT 1' >/dev/null 2>&1; then
|
||||
echo "MySQL ready after ${i}s"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "MySQL not ready after 60s; dumping container logs:"
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
|
||||
exit 1
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
|
||||
2
.github/workflows/deploy-enterprise.yml
vendored
2
.github/workflows/deploy-enterprise.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
|
||||
43
.github/workflows/docker-build.yml
vendored
43
.github/workflows/docker-build.yml
vendored
@ -14,59 +14,28 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
|
||||
build-docker-fork:
|
||||
if: github.event.pull_request.head.repo.full_name != github.repository
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
@ -79,4 +48,6 @@ jobs:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: linux/amd64
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -7,7 +7,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
with:
|
||||
|
||||
30
.github/workflows/main-ci.yml
vendored
30
.github/workflows/main-ci.yml
vendored
@ -23,7 +23,7 @@ concurrency:
|
||||
jobs:
|
||||
pre_job:
|
||||
name: Skip Duplicate Checks
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
|
||||
steps:
|
||||
@ -39,7 +39,7 @@ jobs:
|
||||
name: Check Changed Files
|
||||
needs: pre_job
|
||||
if: needs.pre_job.outputs.should_skip != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
@ -56,9 +56,7 @@ jobs:
|
||||
- 'api/**'
|
||||
- '.github/workflows/api-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.all'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/init-env.sh'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
@ -71,6 +69,7 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@ -84,6 +83,7 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
@ -95,9 +95,7 @@ jobs:
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.all'
|
||||
- 'docker/.env.example'
|
||||
- 'docker/init-env.sh'
|
||||
- 'docker/middleware.env.example'
|
||||
- 'docker/docker-compose.yaml'
|
||||
- 'docker/docker-compose-template.yaml'
|
||||
@ -143,7 +141,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Report skipped API tests
|
||||
run: echo "No API-related changes detected; skipping API tests."
|
||||
@ -156,7 +154,7 @@ jobs:
|
||||
- check-changes
|
||||
- api-tests-run
|
||||
- api-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Finalize API Tests status
|
||||
env:
|
||||
@ -203,7 +201,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Report skipped web tests
|
||||
run: echo "No web-related changes detected; skipping web tests."
|
||||
@ -216,7 +214,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-tests-run
|
||||
- web-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Finalize Web Tests status
|
||||
env:
|
||||
@ -262,7 +260,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Report skipped web full-stack e2e
|
||||
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
|
||||
@ -275,7 +273,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-e2e-run
|
||||
- web-e2e-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Finalize Web Full-Stack E2E status
|
||||
env:
|
||||
@ -327,7 +325,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Report skipped VDB tests
|
||||
run: echo "No VDB-related changes detected; skipping VDB tests."
|
||||
@ -340,7 +338,7 @@ jobs:
|
||||
- check-changes
|
||||
- vdb-tests-run
|
||||
- vdb-tests-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Finalize VDB Tests status
|
||||
env:
|
||||
@ -386,7 +384,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Report skipped DB migration tests
|
||||
run: echo "No migration-related changes detected; skipping DB migration tests."
|
||||
@ -399,7 +397,7 @@ jobs:
|
||||
- check-changes
|
||||
- db-migration-test-run
|
||||
- db-migration-test-skip
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Finalize DB Migration Test status
|
||||
env:
|
||||
|
||||
2
.github/workflows/pyrefly-diff-comment.yml
vendored
2
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with pyrefly diff
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-diff:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with type coverage
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
2
.github/workflows/pyrefly-type-coverage.yml
vendored
2
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-type-coverage:
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
name: Validate PR title
|
||||
permissions:
|
||||
pull-requests: read
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
9
.github/workflows/style.yml
vendored
9
.github/workflows/style.yml
vendored
@ -15,7 +15,7 @@ permissions:
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -57,7 +57,7 @@ jobs:
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
@ -83,6 +83,7 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
@ -109,8 +110,6 @@ jobs:
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
@ -132,7 +131,7 @@ jobs:
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
3
.github/workflows/tool-test-sdks.yaml
vendored
3
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,6 +9,7 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -17,7 +18,7 @@ concurrency:
|
||||
jobs:
|
||||
build:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
||||
4
.github/workflows/translate-i18n-claude.yml
vendored
4
.github/workflows/translate-i18n-claude.yml
vendored
@ -35,7 +35,7 @@ concurrency:
|
||||
jobs:
|
||||
translate:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
|
||||
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
|
||||
4
.github/workflows/vdb-tests-full.yml
vendored
4
.github/workflows/vdb-tests-full.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
@ -50,7 +50,7 @@ jobs:
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
|
||||
4
.github/workflows/vdb-tests.yml
vendored
4
.github/workflows/vdb-tests.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Smoke Tests
|
||||
runs-on: depot-ubuntu-24.04
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
@ -47,7 +47,7 @@ jobs:
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
./docker/init-env.sh
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
|
||||
2
.github/workflows/web-e2e.yml
vendored
2
.github/workflows/web-e2e.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Full-Stack E2E
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
strategy:
|
||||
@ -54,7 +54,7 @@ jobs:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
@ -92,7 +92,7 @@ jobs:
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -219,9 +219,6 @@ node_modules
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
# generated API OpenAPI specs
|
||||
packages/contracts/openapi/
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
@ -240,10 +237,6 @@ scripts/stress-test/reports/
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# vitest browser mode attachments (failure screenshots, traces, etc.)
|
||||
.vitest-attachments/
|
||||
**/__screenshots__/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
@@ -30,7 +30,7 @@ The codebase is split into:
## Language Style

- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.

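A tiny illustration of the strict-TypeScript guidance above; the parsing helper itself is hypothetical.

```typescript
// Prefer `unknown` plus narrowing over `any`, so the strict config keeps checking call sites.
function parseCount(value: unknown): number {
  if (typeof value === 'number')
    return value
  if (typeof value === 'string' && value.trim() !== '')
    return Number.parseInt(value, 10)
  throw new TypeError('Expected a number or numeric string')
}
```
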
## General Practices

26  README.md
@ -76,14 +76,7 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
./init-env.sh
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, initialize `.env`, then run `docker compose up -d` from the `docker` directory.
|
||||
|
||||
```powershell
|
||||
.\init-env.ps1
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
@ -144,7 +137,20 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, edit `docker/.env` after running the initialization script. The full reference remains in [`docker/.env.all`](docker/.env.all). After making any changes, re-run `docker compose up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||
|
||||
```bash
|
||||
# In your .env file
|
||||
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||
```
|
||||
|
||||
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
@ -154,7 +160,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
|
||||
@ -659,11 +659,6 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
@ -714,6 +709,22 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Suggested Questions After Answer Configuration
|
||||
# These environment variables allow customization of the suggested questions feature
|
||||
#
|
||||
# Custom prompt for generating suggested questions (optional)
|
||||
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||
# SUGGESTED_QUESTIONS_PROMPT=
|
||||
|
||||
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||
# Adjust this value for longer questions or more questions
|
||||
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||
|
||||
# Temperature for suggested questions generation (default: 0.0)
|
||||
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
|
||||
@ -101,11 +101,3 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
```
|
||||
uv run dev/generate_swagger_specs.py --output-dir openapi
|
||||
```
|
||||
|
||||
use https://jsontotable.org/openapi-to-typescript to convert to typescript
|
||||
|
||||
@ -113,18 +113,8 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
# Validates name encoding for non-Latin characters.
|
||||
name = name.strip().encode("utf-8").decode("utf-8") if name else None
|
||||
|
||||
# Generate a random password that satisfies the password policy.
|
||||
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
|
||||
for _ in range(100):
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
valid_password(new_password)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
|
||||
return
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@ -287,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Creators Platform integration
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable Creators Platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client ID for Creators Platform integration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -1400,7 +1379,6 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
CreatorsPlatformConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
"name": "Website Generator"
|
||||
},
|
||||
"app_id": "b53545b1-79ea-4da3-b31a-c39391c6f041",
|
||||
"categories": ["Programming"],
|
||||
"category": "Programming",
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -35,7 +35,7 @@
|
||||
"name": "Investment Analysis Report Copilot"
|
||||
},
|
||||
"app_id": "a23b57fa-85da-49c0-a571-3aff375976c1",
|
||||
"categories": ["Agent"],
|
||||
"category": "Agent",
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Welcome to your personalized Investment Analysis Copilot service, where we delve into the depths of stock analysis to provide you with comprehensive insights. \n",
|
||||
"is_listed": true,
|
||||
@ -51,7 +51,7 @@
|
||||
"name": "Workflow Planning Assistant "
|
||||
},
|
||||
"app_id": "f3303a7d-a81c-404e-b401-1f8711c998c1",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "An assistant that helps you plan and select the right node for a workflow (V0.6.0). ",
|
||||
"is_listed": true,
|
||||
@ -67,7 +67,7 @@
|
||||
"name": "Automated Email Reply "
|
||||
},
|
||||
"app_id": "e9d92058-7d20-4904-892f-75d90bef7587",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Reply emails using Gmail API. It will automatically retrieve email in your inbox and create a response in Gmail. \nConfigure your Gmail API in Google Cloud Console. ",
|
||||
"is_listed": true,
|
||||
@ -83,7 +83,7 @@
|
||||
"name": "Book Translation "
|
||||
},
|
||||
"app_id": "98b87f88-bd22-4d86-8b74-86beba5e0ed4",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "A workflow designed to translate a full book up to 15000 tokens per run. Uses Code node to separate text into chunks and Iteration to translate each chunk. ",
|
||||
"is_listed": true,
|
||||
@ -99,7 +99,7 @@
|
||||
"name": "Python bug fixer"
|
||||
},
|
||||
"app_id": "cae337e6-aec5-4c7b-beca-d6f1a808bd5e",
|
||||
"categories": ["Programming"],
|
||||
"category": "Programming",
|
||||
"copyright": null,
|
||||
"description": null,
|
||||
"is_listed": true,
|
||||
@ -115,7 +115,7 @@
|
||||
"name": "Code Interpreter"
|
||||
},
|
||||
"app_id": "d077d587-b072-4f2c-b631-69ed1e7cdc0f",
|
||||
"categories": ["Programming"],
|
||||
"category": "Programming",
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Code interpreter, clarifying the syntax and semantics of the code.",
|
||||
"is_listed": true,
|
||||
@ -131,7 +131,7 @@
|
||||
"name": "SVG Logo Design "
|
||||
},
|
||||
"app_id": "73fbb5f1-c15d-4d74-9cc8-46d9db9b2cca",
|
||||
"categories": ["Agent"],
|
||||
"category": "Agent",
|
||||
"copyright": "Dify.AI",
|
||||
"description": "Hello, I am your creative partner in bringing ideas to vivid life! I can assist you in creating stunning designs by leveraging abilities of DALL·E 3. ",
|
||||
"is_listed": true,
|
||||
@ -147,7 +147,7 @@
|
||||
"name": "Long Story Generator (Iteration) "
|
||||
},
|
||||
"app_id": "5efb98d7-176b-419c-b6ef-50767391ab62",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "A workflow demonstrating how to use Iteration node to generate long article that is longer than the context length of LLMs. ",
|
||||
"is_listed": true,
|
||||
@ -163,7 +163,7 @@
|
||||
"name": "Text Summarization Workflow"
|
||||
},
|
||||
"app_id": "f00c4531-6551-45ee-808f-1d7903099515",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Based on users' choice, retrieve external knowledge to more accurately summarize articles.",
|
||||
"is_listed": true,
|
||||
@ -179,7 +179,7 @@
|
||||
"name": "YouTube Channel Data Analysis"
|
||||
},
|
||||
"app_id": "be591209-2ca8-410f-8f3b-ca0e530dd638",
|
||||
"categories": ["Agent"],
|
||||
"category": "Agent",
|
||||
"copyright": "Dify.AI",
|
||||
"description": "I am a YouTube Channel Data Analysis Copilot, I am here to provide expert data analysis tailored to your needs. ",
|
||||
"is_listed": true,
|
||||
@ -195,7 +195,7 @@
|
||||
"name": "Article Grading Bot"
|
||||
},
|
||||
"app_id": "a747f7b4-c48b-40d6-b313-5e628232c05f",
|
||||
"categories": ["Writing"],
|
||||
"category": "Writing",
|
||||
"copyright": null,
|
||||
"description": "Assess the quality of articles and text based on user defined criteria. ",
|
||||
"is_listed": true,
|
||||
@ -211,7 +211,7 @@
|
||||
"name": "SEO Blog Generator"
|
||||
},
|
||||
"app_id": "18f3bd03-524d-4d7a-8374-b30dbe7c69d5",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Workflow for retrieving information from the internet, followed by segmented generation of SEO blogs.",
|
||||
"is_listed": true,
|
||||
@ -227,7 +227,7 @@
|
||||
"name": "SQL Creator"
|
||||
},
|
||||
"app_id": "050ef42e-3e0c-40c1-a6b6-a64f2c49d744",
|
||||
"categories": ["Programming"],
|
||||
"category": "Programming",
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "Write SQL from natural language by pasting in your schema with the request.Please describe your query requirements in natural language and select the target database type.",
|
||||
"is_listed": true,
|
||||
@ -243,7 +243,7 @@
|
||||
"name": "Sentiment Analysis "
|
||||
},
|
||||
"app_id": "f06bf86b-d50c-4895-a942-35112dbe4189",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Batch sentiment analysis of text, followed by JSON output of sentiment classification along with scores.",
|
||||
"is_listed": true,
|
||||
@ -259,7 +259,7 @@
|
||||
"name": "Strategic Consulting Expert"
|
||||
},
|
||||
"app_id": "7e8ca1ae-02f2-4b5f-979e-62d19133bee2",
|
||||
"categories": ["Assistant"],
|
||||
"category": "Assistant",
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "I can answer your questions related to strategic marketing.",
|
||||
"is_listed": true,
|
||||
@ -275,7 +275,7 @@
|
||||
"name": "Code Converter"
|
||||
},
|
||||
"app_id": "4006c4b2-0735-4f37-8dbb-fb1a8c5bd87a",
|
||||
"categories": ["Programming"],
|
||||
"category": "Programming",
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "This is an application that provides the ability to convert code snippets in multiple programming languages. You can input the code you wish to convert, select the target programming language, and get the desired output.",
|
||||
"is_listed": true,
|
||||
@ -291,7 +291,7 @@
|
||||
"name": "Question Classifier + Knowledge + Chatbot "
|
||||
},
|
||||
"app_id": "d9f6b733-e35d-4a40-9f38-ca7bbfa009f7",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, a chatbot capable of identifying intents alongside with a knowledge base.",
|
||||
"is_listed": true,
|
||||
@ -307,7 +307,7 @@
|
||||
"name": "AI Front-end interviewer"
|
||||
},
|
||||
"app_id": "127efead-8944-4e20-ba9d-12402eb345e0",
|
||||
"categories": ["HR"],
|
||||
"category": "HR",
|
||||
"copyright": "Copyright 2023 Dify",
|
||||
"description": "A simulated front-end interviewer that tests the skill level of front-end development through questioning.",
|
||||
"is_listed": true,
|
||||
@ -323,7 +323,7 @@
|
||||
"name": "Knowledge Retrieval + Chatbot "
|
||||
},
|
||||
"app_id": "e9870913-dd01-4710-9f06-15d4180ca1ce",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Basic Workflow Template, A chatbot with a knowledge base. ",
|
||||
"is_listed": true,
|
||||
@ -339,7 +339,7 @@
|
||||
"name": "Email Assistant Workflow "
|
||||
},
|
||||
"app_id": "dd5b6353-ae9b-4bce-be6a-a681a12cf709",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "A multifunctional email assistant capable of summarizing, replying, composing, proofreading, and checking grammar.",
|
||||
"is_listed": true,
|
||||
@ -355,7 +355,7 @@
|
||||
"name": "Customer Review Analysis Workflow "
|
||||
},
|
||||
"app_id": "9c0cd31f-4b62-4005-adf5-e3888d08654a",
|
||||
"categories": ["Workflow"],
|
||||
"category": "Workflow",
|
||||
"copyright": null,
|
||||
"description": "Utilize LLM (Large Language Models) to classify customer reviews and forward them to the internal system.",
|
||||
"is_listed": true,
|
||||
|
||||
@@ -41,8 +41,7 @@ def guess_file_info_from_response(response: httpx.Response):
    # Try to extract filename from URL
    parsed_url = urllib.parse.urlparse(url)
    url_path = parsed_url.path
    # Decode percent-encoded characters in the path segment
    filename = urllib.parse.unquote(os.path.basename(url_path))
    filename = os.path.basename(url_path)

    # If filename couldn't be extracted, use Content-Disposition header
    if not filename:

@@ -1,6 +0,0 @@
from pydantic import BaseModel, JsonValue


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict[str, JsonValue]
    action: str
@@ -1,5 +1,4 @@
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Literal
@@ -9,7 +8,6 @@ from flask_restx import Resource
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.datastructures import MultiDict
from werkzeug.exceptions import BadRequest

from controllers.common.helpers import FileInfo
@@ -59,7 +57,6 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)

_logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")


class AppListQuery(BaseModel):
@@ -69,19 +66,22 @@ class AppListQuery(BaseModel):
        default="all", description="App mode filter"
    )
    name: str | None = Field(default=None, description="Filter by app name")
    tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
    tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
    is_created_by_me: bool | None = Field(default=None, description="Filter by creator")

    @field_validator("tag_ids", mode="before")
    @classmethod
    def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
    def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
        if not value:
            return None

        if not isinstance(value, list):
            raise ValueError("Unsupported tag_ids type.")
        if isinstance(value, str):
            items = [item.strip() for item in value.split(",") if item.strip()]
        elif isinstance(value, list):
            items = [str(item).strip() for item in value if item and str(item).strip()]
        else:
            raise TypeError("Unsupported tag_ids type.")

        items = [str(item).strip() for item in value if item and str(item).strip()]
        if not items:
            return None

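For reference, the tag_ids handling in the hunk above boils down to the following normalization (a minimal standalone sketch; the helper name is illustrative, and the added/removed marker of each diff line is not preserved in this view):

# Sketch only: mirrors the comma-splitting branch shown in the hunk above.
def normalize_tag_ids(value: str | list[str] | None) -> list[str] | None:
    if not value:
        return None
    if isinstance(value, str):
        items = [item.strip() for item in value.split(",") if item.strip()]
    elif isinstance(value, list):
        items = [str(item).strip() for item in value if item and str(item).strip()]
    else:
        raise TypeError("Unsupported tag_ids type.")
    return items or None

normalize_tag_ids("a, b,,c")  # -> ["a", "b", "c"]
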
@@ -91,26 +91,6 @@ class AppListQuery(BaseModel):
            raise ValueError("Invalid UUID format in tag_ids.") from exc


def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
    normalized: dict[str, str | list[str]] = {}
    indexed_tag_ids: list[tuple[int, str]] = []

    for key in query_args:
        match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
        if match:
            indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
            continue

        value = query_args.get(key)
        if value is not None:
            normalized[key] = value

    if indexed_tag_ids:
        normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]

    return normalized


class CreateAppPayload(BaseModel):
    name: str = Field(..., min_length=1, description="App name")
    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@@ -475,7 +455,7 @@ class AppListApi(Resource):
        """Get app list"""
        current_user, current_tenant_id = current_account_with_tenant()

        args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
        args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
        args_dict = args.model_dump()

        # get app list
@@ -712,32 +692,6 @@ class AppExportApi(Resource):
        return payload.model_dump(mode="json")


@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=None)
    @edit_permission_required
    def post(self, app_model):
        """Publish app to Creators Platform"""
        from configs import dify_config
        from core.helper.creators import get_redirect_url, upload_dsl

        if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
            return {"error": "Creators Platform features are not enabled"}, 403

        current_user, _ = current_account_with_tenant()

        dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
        dsl_bytes = dsl_content.encode("utf-8")

        claim_code = upload_dsl(dsl_bytes)
        redirect_url = get_redirect_url(str(current_user.id), claim_code)

        return {"redirect_url": redirect_url}


@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
    @console_ns.doc("check_app_name")

@@ -60,8 +60,7 @@ _file_access_controller = DatabaseFileAccessController()
LISTENING_RETRY_IN = 2000
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50

# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -159,13 +158,8 @@ class WorkflowFeaturesPayload(BaseModel):
    features: dict[str, Any] = Field(..., description="Workflow feature configuration")


class WorkflowOnlineUsersPayload(BaseModel):
    app_ids: list[str] = Field(default_factory=list, description="App IDs")

    @field_validator("app_ids")
    @classmethod
    def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
        return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
class WorkflowOnlineUsersQuery(BaseModel):
    app_ids: str = Field(..., description="Comma-separated app IDs")


class DraftWorkflowTriggerRunPayload(BaseModel):
@@ -192,7 +186,7 @@ reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersPayload)
reg(WorkflowOnlineUsersQuery)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)

@@ -1390,19 +1384,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):

@console_ns.route("/apps/workflows/online-users")
class WorkflowOnlineUsersApi(Resource):
    @console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
    @console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
    @console_ns.doc("get_workflow_online_users")
    @console_ns.doc(description="Get workflow online users")
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(online_user_list_fields)
    def post(self):
        args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
    def get(self):
        args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore

        app_ids = args.app_ids
        if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
            raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
        app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
        if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
            raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")

        if not app_ids:
            return {"data": []}
@@ -1410,24 +1404,13 @@ class WorkflowOnlineUsersApi(Resource):
        _, current_tenant_id = current_account_with_tenant()
        workflow_service = WorkflowService()
        accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
        ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]

        users_json_by_app_id: dict[str, Any] = {}
        for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
            app_id_batch = ordered_accessible_app_ids[
                start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
            ]
            pipe = redis_client.pipeline(transaction=False)
            for app_id in app_id_batch:
                pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")

            users_json_batch = pipe.execute()
            for app_id, users_json in zip(app_id_batch, users_json_batch):
                users_json_by_app_id[app_id] = users_json

        results = []
        for app_id in ordered_accessible_app_ids:
            users_json = users_json_by_app_id.get(app_id, {})
        for app_id in app_ids:
            if app_id not in accessible_app_ids:
                continue

            users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")

            users = []
            for _, user_info_json in users_json.items():

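For reference, the comma-separated app_ids parsing in the hunk above reduces to an order-preserving de-duplication (a minimal sketch; the function name and the limit default are illustrative):

def parse_app_ids(raw: str, limit: int = 50) -> list[str]:
    # Split, trim, drop empties, and de-duplicate while keeping first-seen order.
    app_ids = list(dict.fromkeys(app_id.strip() for app_id in raw.split(",") if app_id.strip()))
    if len(app_ids) > limit:
        raise ValueError(f"Maximum {limit} app_ids are allowed per request.")
    return app_ids

parse_app_ids("a, b,a ,c")  # -> ["a", "b", "c"]
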
@@ -75,15 +75,14 @@ console_ns.schema_model(


def _convert_values_to_json_serializable_object(value: Segment):
    match value:
        case FileSegment():
            return value.value.model_dump()
        case ArrayFileSegment():
            return [i.model_dump() for i in value.value]
        case SegmentGroup():
            return [_convert_values_to_json_serializable_object(i) for i in value.value]
        case _:
            return value.value
    if isinstance(value, FileSegment):
        return value.value.model_dump()
    elif isinstance(value, ArrayFileSegment):
        return [i.model_dump() for i in value.value]
    elif isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(i) for i in value.value]
    else:
        return value.value


def _serialize_var_value(variable: WorkflowDraftVariable):

@ -38,48 +38,6 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
|
||||
if normalized_record.get("child_chunks") is None:
|
||||
normalized_record["child_chunks"] = []
|
||||
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
@ -117,12 +75,7 @@ class DatasetsHitTestingBase:
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@@ -52,7 +52,7 @@ class RecommendedAppResponse(ResponseModel):
    copyright: str | None = None
    privacy_policy: str | None = None
    custom_disclaimer: str | None = None
    categories: list[str] = Field(default_factory=list)
    category: str | None = None
    position: int | None = None
    is_listed: bool | None = None
    can_trial: bool | None = None

@ -8,10 +8,10 @@ from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
@ -34,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
|
||||
@ -32,7 +32,7 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
@ -152,68 +152,41 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings")
|
||||
class TagBindingCollectionApi(Resource):
|
||||
"""Canonical collection resource for tag binding creation."""
|
||||
|
||||
@console_ns.doc("create_tag_binding")
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingRemoveApi(Resource):
|
||||
"""Batch resource for tag binding deletion."""
|
||||
|
||||
@console_ns.doc("remove_tag_bindings")
|
||||
@console_ns.doc(description="Remove one or more tag bindings from a target.")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -8,7 +8,6 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -46,8 +45,6 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@ -325,24 +322,9 @@ class AccountAvatarApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
avatar = args.avatar
|
||||
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.tenant_id != current_tenant_id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
|
||||
@ -1,11 +1,3 @@
|
||||
"""Console workspace endpoint controllers.
|
||||
|
||||
This module exposes workspace-scoped plugin endpoint management APIs. The
|
||||
canonical write routes follow resource-oriented paths, while the historical
|
||||
verb-based aliases stay available as deprecated resources so OpenAPI metadata
|
||||
marks only the legacy paths as deprecated.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
@ -33,12 +25,7 @@ class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(BaseModel):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LegacyEndpointUpdatePayload(EndpointIdPayload):
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
@ -89,7 +76,6 @@ register_schema_models(
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
LegacyEndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
@ -102,60 +88,8 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints")
|
||||
class EndpointCollectionApi(Resource):
|
||||
"""Canonical collection resource for endpoint creation."""
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@ -170,33 +104,22 @@ class EndpointCollectionApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class DeprecatedEndpointCreateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -267,56 +190,10 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/<string:id>")
|
||||
class EndpointItemApi(Resource):
|
||||
"""Canonical item resource for endpoint updates and deletion."""
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class DeprecatedEndpointDeleteApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for deleting a plugin endpoint. "
|
||||
"Use DELETE /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
@ -329,23 +206,22 @@ class DeprecatedEndpointDeleteApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class DeprecatedEndpointUpdateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint updates."""
|
||||
|
||||
@console_ns.doc("update_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating a plugin endpoint. "
|
||||
"Use PATCH /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
|
||||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
@ -357,8 +233,19 @@ class DeprecatedEndpointUpdateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
|
||||
@@ -876,10 +876,10 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
    @login_required
    @account_initialization_required
    def post(self, provider):
        _, current_tenant_id = current_account_with_tenant()
        current_user, current_tenant_id = current_account_with_tenant()
        payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
        return BuiltinToolManageService.set_default_provider(
            tenant_id=current_tenant_id, provider=provider, id=payload.id
            tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
        )


@ -8,12 +8,13 @@ paused human input forms in workflow/chatflow runs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@ -25,6 +26,11 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
@ -121,7 +127,7 @@ class WorkflowHumanInputFormApi(Resource):
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%s", form.id)
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
raise InternalServerError("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
|
||||
@ -18,8 +18,6 @@ from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
@ -37,14 +35,8 @@ class WorkflowEventsApi(Resource):
|
||||
params={
|
||||
"task_id": "Workflow run ID",
|
||||
"user": "End user identifier (query param)",
|
||||
"include_state_snapshot": (
|
||||
"Whether to replay from persisted state snapshot, "
|
||||
'specify `"true"` to include a status snapshot of executed nodes'
|
||||
),
|
||||
"continue_on_pause": (
|
||||
"Whether to keep the stream open across workflow_paused events,"
|
||||
'specify `"true"` to keep the stream open for `workflow_paused` events.'
|
||||
),
|
||||
"include_state_snapshot": "Whether to replay from persisted state snapshot",
|
||||
"continue_on_pause": "Whether to keep the stream open across workflow_paused events",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
@ -107,7 +99,7 @@ class WorkflowEventsApi(Resource):
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
terminal_events = ["workflow_finished"] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
@ -118,7 +110,6 @@ class WorkflowEventsApi(Resource):
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -100,27 +100,9 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
|
||||
|
||||
tag_ids: list[str] = Field(default_factory=list)
|
||||
tag_id: str | None = None
|
||||
tag_id: str
|
||||
target_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_legacy_tag_id(cls, data: object) -> object:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("tag_ids") and data.get("tag_id"):
|
||||
return {**data, "tag_ids": [data["tag_id"]]}
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tag_ids(self) -> "TagUnbindingPayload":
|
||||
if not self.tag_ids:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return self
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
@ -619,11 +601,11 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tags")
|
||||
@service_api_ns.doc(description="Unbind tags from a dataset")
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Tags unbound successfully",
|
||||
204: "Tag unbound successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
@ -636,7 +618,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,12 +1,4 @@
|
||||
"""Service API endpoints for dataset document management.
|
||||
|
||||
The canonical Service API paths use hyphenated route segments. Legacy underscore
|
||||
aliases remain registered for backward compatibility, but they must stay marked
|
||||
deprecated in generated API docs so clients migrate toward the canonical paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
@ -125,137 +117,12 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Create a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id_str,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for the canonical text document creation route."""
|
||||
"""Resource for documents."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@ -271,43 +138,81 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document creation."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a new document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/document/create-by-text instead."
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document created successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
400: "Bad request - invalid parameters",
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text through the deprecated underscore alias."""
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for the canonical text document update route."""
|
||||
"""Resource for update documents."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@ -324,35 +229,62 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document updates."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text through the deprecated underscore alias."""
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -468,98 +400,15 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from an uploaded file for canonical and deprecated routes."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
    )

    if not dataset:
        raise ValueError("Dataset does not exist.")

    if dataset.provider == "external":
        raise ValueError("External datasets are not supported.")

    args: dict[str, object] = {}
    if "data" in request.form:
        args = json.loads(request.form["data"])
        if "doc_form" not in args:
            args["doc_form"] = dataset.chunk_structure or "text_model"
        if "doc_language" not in args:
            args["doc_language"] = "English"

    # indexing_technique is already set in dataset since this is an update
    args["indexing_technique"] = dataset.indexing_technique

    if "file" in request.files:
        # save file info
        file = request.files["file"]

        if len(request.files) > 1:
            raise TooManyFilesError()

        if not file.filename:
            raise FilenameNotExistsError

        if not current_user:
            raise ValueError("current_user is required")

        try:
            upload_file = FileService(db.engine).upload_file(
                filename=file.filename,
                content=file.read(),
                mimetype=file.mimetype,
                user=current_user,
                source="datasets",
            )
        except services.errors.file.FileTooLargeError as file_too_large_error:
            raise FileTooLargeError(file_too_large_error.description)
        except services.errors.file.UnsupportedFileTypeError:
            raise UnsupportedFileTypeError()
        data_source = {
            "type": "upload_file",
            "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
        }
        args["data_source"] = data_source

    # validate args
    args["original_document_id"] = str(document_id)

    knowledge_config = KnowledgeConfig.model_validate(args)
    DocumentService.document_create_args_validate(knowledge_config)

    try:
        documents, _ = DocumentService.save_document_with_dataset_id(
            dataset=dataset,
            knowledge_config=knowledge_config,
            account=dataset.created_by_account,
            dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
            created_from="api",
        )
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description)
    document = documents[0]
    documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
    return documents_and_batch_fields, 200


@service_api_ns.route(
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
    """Deprecated resource aliases for file document updates."""
class DocumentUpdateByFileApi(DatasetApiResource):
    """Resource for update documents."""

    @service_api_ns.doc("update_document_by_file_deprecated")
    @service_api_ns.doc(deprecated=True)
    @service_api_ns.doc(
        description=(
            "Deprecated legacy alias for updating an existing document by uploading a file. "
            "Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
        )
    )
    @service_api_ns.doc("update_document_by_file")
    @service_api_ns.doc(description="Update an existing document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
@ -570,9 +419,82 @@ class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by file through the deprecated file-update aliases."""
        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
    def post(self, tenant_id, dataset_id, document_id):
        """Update document by upload file."""
        dataset = db.session.scalar(
            select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
        )

        if not dataset:
            raise ValueError("Dataset does not exist.")

        if dataset.provider == "external":
            raise ValueError("External datasets are not supported.")

        args = {}
        if "data" in request.form:
            args = json.loads(request.form["data"])
            if "doc_form" not in args:
                args["doc_form"] = dataset.chunk_structure or "text_model"
            if "doc_language" not in args:
                args["doc_language"] = "English"

        # get dataset info
        dataset_id = str(dataset_id)
        tenant_id = str(tenant_id)

        # indexing_technique is already set in dataset since this is an update
        args["indexing_technique"] = dataset.indexing_technique

        if "file" in request.files:
            # save file info
            file = request.files["file"]

            if len(request.files) > 1:
                raise TooManyFilesError()

            if not file.filename:
                raise FilenameNotExistsError

            if not current_user:
                raise ValueError("current_user is required")

            try:
                upload_file = FileService(db.engine).upload_file(
                    filename=file.filename,
                    content=file.read(),
                    mimetype=file.mimetype,
                    user=current_user,
                    source="datasets",
                )
            except services.errors.file.FileTooLargeError as file_too_large_error:
                raise FileTooLargeError(file_too_large_error.description)
            except services.errors.file.UnsupportedFileTypeError:
                raise UnsupportedFileTypeError()
            data_source = {
                "type": "upload_file",
                "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
            }
            args["data_source"] = data_source
        # validate args
        args["original_document_id"] = str(document_id)

        knowledge_config = KnowledgeConfig.model_validate(args)
        DocumentService.document_create_args_validate(knowledge_config)

        try:
            documents, _ = DocumentService.save_document_with_dataset_id(
                dataset=dataset,
                knowledge_config=knowledge_config,
                account=dataset.created_by_account,
                dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
                created_from="api",
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        document = documents[0]
        documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
        return documents_and_batch_fields, 200
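For reference, a minimal client-side sketch of how this update-by-file endpoint is exercised. The base URL, token, IDs, and file name are placeholders; only the payload shape (a JSON string under "data" plus a single "file" part) is taken from the handler above.

# Hypothetical client sketch for the update_by_file alias shown above.
# BASE_URL, API_TOKEN, DATASET_ID and DOCUMENT_ID are placeholders.
import json

import httpx

BASE_URL = "https://example.com/service-api"  # placeholder service API prefix
API_TOKEN = "dataset-api-token"
DATASET_ID = "00000000-0000-0000-0000-000000000001"
DOCUMENT_ID = "00000000-0000-0000-0000-000000000002"

url = f"{BASE_URL}/datasets/{DATASET_ID}/documents/{DOCUMENT_ID}/update_by_file"
# "data" carries the JSON-encoded document settings; "file" is the replacement file.
form_data = {"data": json.dumps({"doc_form": "text_model", "doc_language": "English"})}

with open("handbook.pdf", "rb") as f:
    response = httpx.post(
        url,
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        data=form_data,
        files={"file": ("handbook.pdf", f, "application/pdf")},
        timeout=60,
    )
response.raise_for_status()
print(response.json()["batch"])  # handler returns the marshalled document plus its batch id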


@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
@ -886,22 +808,6 @@ class DocumentApi(DatasetApiResource):

        return response

    @service_api_ns.doc("update_document_by_file")
    @service_api_ns.doc(description="Update an existing document by uploading a file")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
    @service_api_ns.doc(
        responses={
            200: "Document updated successfully",
            401: "Unauthorized - invalid API token",
            404: "Document not found",
        }
    )
    @cloud_edition_billing_resource_check("vector_space", "dataset")
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
        """Update document by file on the canonical document resource."""
        return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)

    @service_api_ns.doc("delete_document")
    @service_api_ns.doc(description="Delete a document")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})

@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from graphon.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
    AudioTooLargeServiceError,
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
            500: "Internal Server Error",
        }
    )
    def post(self, app_model: App, end_user: EndUser):
    def post(self, app_model: App, end_user):
        """Convert audio to text"""
        file = request.files["file"]

        try:
            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
            response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)

            return response
        except services.errors.app_model_config.AppModelConfigBrokenError:
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
            500: "Internal Server Error",
        }
    )
    def post(self, app_model: App, end_user: EndUser):
    def post(self, app_model: App, end_user):
        """Convert text to audio"""
        try:
            payload = TextToAudioPayload.model_validate(web_ns.payload or {})

@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict

from flask import Response, request
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
@ -26,6 +26,11 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)


class HumanInputFormSubmitPayload(BaseModel):
    inputs: dict
    action: str


_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
    prefix="web_form_submit_rate_limit",
    max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,

@ -1,7 +1,5 @@
from typing import Any

CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000


class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
@ -22,11 +20,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
    @classmethod
    def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Validate and set defaults for suggested questions feature.

        Optional fields:
        - prompt: custom instruction prompt.
        - model: provider/model configuration for suggested question generation.
        Validate and set defaults for suggested questions feature

        :param config: app model config args
        """
@ -45,27 +39,4 @@ class SuggestedQuestionsAfterAnswerConfigManager:
        if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
            raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")

        prompt = config["suggested_questions_after_answer"].get("prompt")
        if prompt is not None and not isinstance(prompt, str):
            raise ValueError("prompt in suggested_questions_after_answer must be of string type")
        if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
            raise ValueError(
                f"prompt in suggested_questions_after_answer must be less than or equal to "
                f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
            )

        if "model" in config["suggested_questions_after_answer"]:
            model_config = config["suggested_questions_after_answer"]["model"]
            if not isinstance(model_config, dict):
                raise ValueError("model in suggested_questions_after_answer must be of object type")

            if "provider" not in model_config or not isinstance(model_config["provider"], str):
                raise ValueError("provider in suggested_questions_after_answer.model must be of string type")

            if "name" not in model_config or not isinstance(model_config["name"], str):
                raise ValueError("name in suggested_questions_after_answer.model must be of string type")

            if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
                raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")

        return config, ["suggested_questions_after_answer"]

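To make the validation contract above concrete, here is a minimal sketch of a config payload that satisfies the prompt/model checks being removed in this hunk. The provider and model names are illustrative placeholders, not values the codebase requires.

# Illustrative only: a config shape accepted by the pre-revert validate_and_set_defaults.
config = {
    "suggested_questions_after_answer": {
        "enabled": True,
        "prompt": "Suggest follow-ups about billing topics only.",  # must be a string, <= 1000 chars
        "model": {
            "provider": "openai",                         # placeholder provider name, must be a string
            "name": "gpt-4o-mini",                        # placeholder model name, must be a string
            "completion_params": {"temperature": 0.2},    # optional, must be a dict when present
        },
    }
}

validated, affected_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(config)
# affected_keys is ["suggested_questions_after_answer"] per the return statement above.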
@ -35,8 +35,8 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
@ -660,9 +660,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse
|
||||
| AdvancedChatPausedBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse, None, None]
|
||||
ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
|
||||
@ -3,35 +3,34 @@ from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
|
||||
if isinstance(blocking_response, ChatbotAppPausedBlockingResponse):
|
||||
paused_data = blocking_response.data.model_dump(mode="json")
|
||||
return {
|
||||
"event": StreamEvent.WORKFLOW_PAUSED.value,
|
||||
"event": "workflow_paused",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -45,7 +44,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
}
|
||||
|
||||
response = {
|
||||
"event": StreamEvent.MESSAGE.value,
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -60,7 +59,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
|
||||
@ -53,11 +53,10 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
@ -75,7 +74,7 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -218,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
Generator[ChatbotAppStreamResponse, None, None],
|
||||
]:
|
||||
"""
|
||||
@ -238,7 +237,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
|
||||
) -> Union[ChatbotAppBlockingResponse, ChatbotAppPausedBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@ -250,9 +249,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
return ChatbotAppPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
data=ChatbotAppPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
@ -296,17 +295,18 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _build_paused_blocking_response_from_human_input(
|
||||
self, human_input_responses: list[HumanInputRequiredResponse]
|
||||
) -> AdvancedChatPausedBlockingResponse:
|
||||
) -> ChatbotAppPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
reasons = []
|
||||
for response in human_input_responses:
|
||||
reason = response.data.model_dump(mode="json")
|
||||
reason["type"] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
reasons.append(reason)
|
||||
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
return ChatbotAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
data=ChatbotAppPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -70,7 +68,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +99,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -3,8 +3,6 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@ -110,13 +108,13 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
|
||||
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
@ -130,7 +128,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data: dict[str, JsonValue] | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -70,7 +68,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -101,7 +99,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
api/core/app/apps/common/pause_reason_serializer.py (new file, 17 lines)
@ -0,0 +1,17 @@
from collections.abc import Mapping
from typing import Any

from graphon.entities.pause_reason import PauseReason


def pause_reason_to_public_dict(reason: PauseReason | Mapping[str, Any]) -> dict[str, Any]:
    if isinstance(reason, Mapping):
        data = dict(reason)
    else:
        data = dict(reason.model_dump(mode="json"))

    discriminator = data.pop("TYPE", None)
    if discriminator is not None:
        data["type"] = discriminator

    return data

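A short usage sketch of the serializer above; the field values are made up, but the shape mirrors the pause-reason payloads handled elsewhere in this diff. The point is that the internal TYPE discriminator is renamed to a public "type" key before the dict reaches clients.

# Illustrative input/output for pause_reason_to_public_dict (values are placeholders).
raw_reason = {
    "TYPE": "human_input_required",
    "form_id": "form-123",
    "node_id": "node-1",
}

public = pause_reason_to_public_dict(raw_reason)
# {'form_id': 'form-123', 'node_id': 'node-1', 'type': 'human_input_required'}
assert "TYPE" not in public and public["type"] == "human_input_required"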
@ -9,7 +9,9 @@ from typing import Any, NewType, TypedDict, Union
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.common.pause_reason_serializer import pause_reason_to_public_dict
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.human_input_policy import enrich_human_input_pause_reasons
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
@ -52,7 +54,6 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -318,7 +319,7 @@ class WorkflowResponseConverter:
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
pause_reasons = [pause_reason_to_public_dict(reason) for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
@ -337,15 +338,7 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=(
|
||||
HumanInputSurface.SERVICE_API
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
||||
else None
|
||||
),
|
||||
)
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -22,7 +20,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
response: dict[str, Any] = {
|
||||
response = {
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
@ -69,7 +67,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
@ -99,7 +97,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
|
||||
@ -59,7 +59,7 @@ def stream_topic_events(


def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
    if terminal_events is None:
    if not terminal_events:
        return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
    values: set[str] = set()
    for item in terminal_events:

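The changed guard is subtle: `is None` treats an explicitly empty iterable as "no terminal events", while `not terminal_events` also maps an empty set or list onto the default pair. A tiny sketch of the behavioral difference, using placeholder event names:

# Behavioral difference between the two guards (illustrative values).
DEFAULTS = {"workflow_finished", "workflow_paused"}

def normalize_is_none(events):
    return DEFAULTS if events is None else {str(e) for e in events}

def normalize_falsy(events):
    return DEFAULTS if not events else {str(e) for e in events}

assert normalize_is_none([]) == set()    # caller explicitly opted out of terminal events
assert normalize_falsy([]) == DEFAULTS   # empty input now falls back to the defaults
assert normalize_is_none(None) == normalize_falsy(None) == DEFAULTS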
@ -63,7 +63,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
response_chunk.update(cast(dict[str, object], data))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
@ -92,9 +92,9 @@ class WorkflowAppGenerateResponseConverter(
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
response_chunk.update(cast(dict[str, object], data))
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
response_chunk.update(cast(dict[str, object], sub_stream_response.to_ignore_detail_dict()))
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@ -42,7 +42,6 @@ from core.app.entities.queue_entities import (
)
from core.app.entities.task_entities import (
    ErrorStreamResponse,
    HumanInputRequiredPauseReasonPayload,
    HumanInputRequiredResponse,
    MessageAudioEndStreamResponse,
    MessageAudioStreamResponse,
@ -200,11 +199,14 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
    ) -> WorkflowAppPausedBlockingResponse:
        runtime_state = self._resolve_graph_runtime_state()
        paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
        created_at = int(runtime_state.start_at)
        reasons = [
            HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
            for response in human_input_responses
        ]
        # Graph runtime `start_at` is a perf-counter value, not an epoch timestamp, so
        # fallback API payloads need a wall-clock source for `created_at`.
        created_at = int(time.time())
        reasons = []
        for response in human_input_responses:
            reason = response.data.model_dump(mode="json")
            reason["type"] = "human_input_required"
            reasons.append(reason)

        return WorkflowAppPausedBlockingResponse(
            task_id=self._application_generate_entity.task_id,

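The comment in this hunk is worth illustrating: `time.perf_counter()` measures relative durations from an arbitrary origin, so casting it to an epoch-style `created_at` produces meaningless timestamps. A minimal sketch of the distinction:

import time

start = time.perf_counter()   # relative clock: seconds since an arbitrary origin
wall = time.time()            # wall clock: seconds since the Unix epoch

# int(start) would serialize as a 1970-era timestamp in an API payload,
# which is why the fallback response takes created_at from the wall clock.
created_at = int(wall)
elapsed = time.perf_counter() - start  # perf_counter remains the right tool for durations
print(created_at, elapsed)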
@ -1,13 +1,12 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
@ -296,40 +295,6 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
"""
|
||||
Public pause-reason payload used by blocking responses when only
|
||||
``human_input_required`` events are available.
|
||||
"""
|
||||
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@classmethod
|
||||
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
|
||||
return cls(
|
||||
form_id=data.form_id,
|
||||
node_id=data.node_id,
|
||||
node_title=data.node_title,
|
||||
form_content=data.form_content,
|
||||
inputs=data.inputs,
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
@ -390,7 +355,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -447,7 +412,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
def to_ignore_detail_dict(self):
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -809,7 +774,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
|
||||
class ChatbotAppPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
ChatbotAppPausedBlockingResponse entity
|
||||
"""
|
||||
@ -828,7 +793,7 @@ class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: WorkflowExecutionStatus
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Generator  # Changed from Iterator
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:


@contextmanager
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]:  # Changed from Iterator[None]
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
    token = _current_file_access_scope.set(scope)
    try:
        yield

@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
||||
@ -15,21 +14,8 @@ from graphon.nodes.llm.protocols import CredentialsProvider
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
"""Resolves and returns LLM credentials for a given provider and model.
|
||||
|
||||
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||
Because of that cache, a single instance can return stale credentials after
|
||||
the tenant or provider configuration changes (e.g. API key rotation).
|
||||
|
||||
Do **not** keep one instance for the lifetime of a process or across
|
||||
unrelated invocations. Create a new provider per request, workflow run, or
|
||||
other bounded scope where up-to-date credentials matter.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -44,12 +30,8 @@ class DifyCredentialsProvider:
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
self.provider_manager = provider_manager
|
||||
self.credentials_cache = {}
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
if (provider_name, model_name) in self.credentials_cache:
|
||||
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
@ -64,7 +46,6 @@ class DifyCredentialsProvider:
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -84,8 +65,7 @@ class DifyModelFactory:
|
||||
provider_manager=create_plugin_provider_manager(
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
),
|
||||
enable_credentials_cache=True,
|
||||
)
|
||||
)
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -104,7 +84,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
|
||||
model_manager = ModelManager(provider_manager=provider_manager)
|
||||
|
||||
return (
|
||||
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
||||
|
||||
@ -1,41 +0,0 @@
"""
Helper module for Creators Platform integration.

Provides functionality to upload DSL files to the Creators Platform
and generate redirect URLs with OAuth authorization codes.
"""

import logging
from urllib.parse import urlencode

import httpx
from yarl import URL

from configs import dify_config

logger = logging.getLogger(__name__)

creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))


def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
    url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
    response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
    response.raise_for_status()
    data = response.json()
    claim_code = data.get("data", {}).get("claim_code")
    if not claim_code:
        raise ValueError("Creators Platform did not return a valid claim_code")
    return claim_code


def get_redirect_url(user_account_id: str, claim_code: str) -> str:
    base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
    params: dict[str, str] = {"dsl_claim_code": claim_code}
    client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
    if client_id:
        from services.oauth_server import OAuthServerService

        oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
        params["oauth_code"] = oauth_code
    return f"{base_url}?{urlencode(params)}"

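For context, a brief sketch of the two-step flow implemented by the helper being deleted above; file names and account ids are placeholders.

# Illustrative flow for the removed Creators Platform helper (placeholders throughout).
from pathlib import Path

dsl_bytes = Path("my_app.yaml").read_bytes()
claim_code = upload_dsl(dsl_bytes, filename="my_app.yaml")
redirect_url = get_redirect_url(user_account_id="account-123", claim_code=claim_code)
# The caller redirects the browser to redirect_url, which carries dsl_claim_code
# (and oauth_code when an OAuth client id is configured).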
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NotRequired, Protocol, TypedDict, cast
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from sqlalchemy import select
|
||||
@ -18,6 +18,8 @@ from core.llm_generator.prompts import (
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
@ -39,36 +41,6 @@ from models.workflow import Workflow
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuggestedQuestionsModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
completion_params: NotRequired[dict[str, object]]
|
||||
|
||||
|
||||
def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]:
|
||||
"""
|
||||
Normalize raw completion params into invocation parameters and stop sequences.
|
||||
|
||||
This mirrors the app-model access path by separating ``stop`` from provider
|
||||
parameters before invocation, then drops non-positive token limits because
|
||||
some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their
|
||||
provider-specific output-token field.
|
||||
"""
|
||||
normalized_parameters = dict(completion_params)
|
||||
stop_value = normalized_parameters.pop("stop", [])
|
||||
if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value):
|
||||
stop = stop_value
|
||||
else:
|
||||
stop = []
|
||||
|
||||
for token_limit_key in ("max_tokens", "max_output_tokens"):
|
||||
token_limit = normalized_parameters.get(token_limit_key)
|
||||
if isinstance(token_limit, int | float) and token_limit <= 0:
|
||||
normalized_parameters.pop(token_limit_key, None)
|
||||
|
||||
return normalized_parameters, stop
|
||||
|
||||
|
||||
class WorkflowServiceInterface(Protocol):
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
pass
|
||||
@ -151,15 +123,8 @@ class LLMGenerator:
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def generate_suggested_questions_after_answer(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
histories: str,
|
||||
*,
|
||||
instruction_prompt: str | None = None,
|
||||
model_config: object | None = None,
|
||||
) -> Sequence[str]:
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt)
|
||||
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")
|
||||
@ -168,36 +133,10 @@ class LLMGenerator:
|
||||
|
||||
try:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {}
|
||||
provider = configured_model.get("provider")
|
||||
model_name = configured_model.get("name")
|
||||
use_configured_model = False
|
||||
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
try:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
use_configured_model = True
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to use configured suggested-questions model %s/%s, fallback to default model",
|
||||
provider,
|
||||
model_name,
|
||||
exc_info=True,
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
else:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return []
|
||||
|
||||
@ -206,29 +145,19 @@ class LLMGenerator:
|
||||
questions: Sequence[str] = []
|
||||
|
||||
try:
|
||||
configured_completion_params = configured_model.get("completion_params")
|
||||
if use_configured_model and isinstance(configured_completion_params, dict):
|
||||
model_parameters, stop = _normalize_completion_params(configured_completion_params)
|
||||
elif use_configured_model:
|
||||
model_parameters = {}
|
||||
stop = []
|
||||
else:
|
||||
# Default-model generation keeps the built-in suggested-questions tuning.
|
||||
model_parameters = {
|
||||
"max_tokens": 2560,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
stop = []
|
||||
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
model_parameters={
|
||||
"max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
"temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
},
|
||||
stream=False,
|
||||
)
|
||||
|
||||
text_content = response.message.get_text_content()
|
||||
questions = output_parser.parse(text_content) if text_content else []
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception:
|
||||
logger.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
@ -3,28 +3,17 @@ import logging
import re
from collections.abc import Sequence

from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

logger = logging.getLogger(__name__)


class SuggestedQuestionsAfterAnswerOutputParser:
    def __init__(self, instruction_prompt: str | None = None) -> None:
        self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)

    @staticmethod
    def _build_instruction_prompt(instruction_prompt: str | None) -> str:
        if not instruction_prompt or not instruction_prompt.strip():
            return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

        return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'

    def get_format_instructions(self) -> str:
        return self._instruction_prompt
        return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

    def parse(self, text: str) -> Sequence[str]:
        stripped_text = text.strip()
        action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL)
        action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
        questions: list[str] = []
        if action_match is not None:
            try:
@ -34,6 +23,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
            else:
                if isinstance(json_obj, list):
                    questions = [question for question in json_obj if isinstance(question, str)]
        elif stripped_text:
            logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200])
        return questions

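To illustrate what the parser above tolerates, a small sketch with made-up LLM output: the regex accepts a JSON array embedded in surrounding prose, and non-string entries are dropped by the isinstance filter.

# Illustrative inputs for SuggestedQuestionsAfterAnswerOutputParser.parse (text is made up).
parser = SuggestedQuestionsAfterAnswerOutputParser()

clean = parser.parse('["How do I reset my password?", "Where are invoices stored?"]')
noisy = parser.parse('Sure! Here are some ideas:\n["What is the refund window?", 42]\nHope that helps.')

# The regex extracts the first [...] block wherever it appears in the text, and the
# isinstance filter keeps only string entries, so numbers such as 42 are dropped.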
@ -1,4 +1,5 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
import os

CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.

@ -95,8 +96,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
)


# Default prompt and model parameters for suggested questions.
DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
# Default prompt for suggested questions (can be overridden by environment variable)
_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
    "Please help me predict the three most likely questions that human would ask, "
    "and keep each question under 20 characters.\n"
    "MAKE SURE your output is the SAME language as the Assistant's latest response. "
@ -104,6 +105,15 @@ DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    '["question1","question2","question3"]\n'
)

# Environment variable override for suggested questions prompt
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
    "SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
)

# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))

GENERATOR_QA_PROMPT = (
    "<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
    " in the long text. Please think step by step."

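Because these values are read with `os.getenv` at import time, overrides must be present in the process environment before the module is imported. A minimal sketch of how a deployment might exercise the overrides; the values shown are placeholders.

# Illustrative only: environment overrides are resolved once, at module import.
import os

os.environ["SUGGESTED_QUESTIONS_PROMPT"] = "Propose two short follow-up questions."
os.environ["SUGGESTED_QUESTIONS_MAX_TOKENS"] = "128"
os.environ["SUGGESTED_QUESTIONS_TEMPERATURE"] = "0.3"

# Importing after the environment is set (e.g. in a fresh worker process) picks up
# the overrides; changing os.environ later has no effect on the constants.
from core.llm_generator import prompts

print(prompts.SUGGESTED_QUESTIONS_MAX_TOKENS)  # 128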
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
@ -37,13 +36,11 @@ class ModelInstance:
|
||||
Model instance class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
if credentials is None:
|
||||
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.credentials = credentials
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
# Runtime LLM invocation fields.
|
||||
self.parameters: Mapping[str, Any] = {}
|
||||
self.stop: Sequence[str] = ()
|
||||
@ -437,30 +434,8 @@ class ModelInstance:
|
||||
|
||||
|
||||
class ModelManager:
|
||||
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
|
||||
|
||||
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
|
||||
``(tenant_id, provider, model_type, model)`` are stored in
|
||||
``_credentials_cache`` and reused. That can return **stale** credentials after
|
||||
API keys or provider settings change, so a manager constructed with
|
||||
``enable_credentials_cache=True`` should not be kept for the lifetime of a
|
||||
process or shared across unrelated work. Prefer a new manager per request,
|
||||
workflow run, or similar bounded scope.
|
||||
|
||||
The default is ``enable_credentials_cache=False``; in that mode the internal
|
||||
credential cache is not populated, and each ``get_model_instance`` call
|
||||
loads credentials from the current provider configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
*,
|
||||
enable_credentials_cache: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
self._provider_manager = provider_manager
|
||||
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
|
||||
self._enable_credentials_cache = enable_credentials_cache
|
||||
|
||||
@classmethod
|
||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||
@ -488,19 +463,8 @@ class ModelManager:
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
cred_cache_key = (tenant_id, provider, model_type.value, model)
|
||||
|
||||
if cred_cache_key in self._credentials_cache:
|
||||
return ModelInstance(
|
||||
provider_model_bundle,
|
||||
model,
|
||||
deepcopy(self._credentials_cache[cred_cache_key]),
|
||||
)
|
||||
|
||||
ret = ModelInstance(provider_model_bundle, model)
|
||||
if self._enable_credentials_cache:
|
||||
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
|
||||
return ret
|
||||
model_instance = ModelInstance(provider_model_bundle, model)
|
||||
return model_instance
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
|
||||
@ -151,12 +151,6 @@ def deserialize_response(raw_data: bytes) -> Response:

    response = Response(response=body, status=status_code)

    # Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
    # parsed ones so we faithfully reproduce the original response. Use Headers.add
    # rather than dict-style assignment so that repeated headers such as Set-Cookie
    # (and any other multi-valued header per RFC 9110) are preserved instead of
    # being overwritten.
    response.headers.clear()
    for line in lines[1:]:
        if not line:
            continue
@ -164,6 +158,6 @@ def deserialize_response(raw_data: bytes) -> Response:
        if ":" not in line_str:
            continue
        name, value = line_str.split(":", 1)
        response.headers.add(name, value.strip())
        response.headers[name] = value.strip()

    return response

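The header comment being removed here is the crux of the change; a small werkzeug sketch of the difference between `Headers.add` and item assignment for repeated headers such as Set-Cookie (cookie values are placeholders):

from werkzeug.datastructures import Headers

added = Headers()
added.add("Set-Cookie", "session=abc")
added.add("Set-Cookie", "csrf=def")
print(added.getlist("Set-Cookie"))     # ['session=abc', 'csrf=def'] - both cookies survive

assigned = Headers()
assigned["Set-Cookie"] = "session=abc"
assigned["Set-Cookie"] = "csrf=def"
print(assigned.getlist("Set-Cookie"))  # ['csrf=def'] - the first cookie is overwritten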
@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@ -70,32 +70,12 @@ class ProviderManager:
|
||||
Request-bound managers may carry caller identity in that runtime, and the
|
||||
resulting ``ProviderConfiguration`` objects must reuse it for downstream
|
||||
model-type and schema lookups.
|
||||
|
||||
Configuration assembly is cached per manager instance so call chains that
|
||||
share one request-scoped manager can reuse the same provider graph instead
|
||||
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
|
||||
when a long-lived manager needs to observe writes performed within the same
|
||||
instance scope.
|
||||
"""
|
||||
|
||||
decoding_rsa_key: Any | None
|
||||
decoding_cipher_rsa: Any | None
|
||||
_model_runtime: ModelRuntime
|
||||
_configurations_cache: dict[str, ProviderConfigurations]
|
||||
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
self._model_runtime = model_runtime
|
||||
self._configurations_cache = {}
|
||||
|
||||
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
|
||||
"""Drop assembled provider configurations cached on this manager instance."""
|
||||
if tenant_id is None:
|
||||
self._configurations_cache.clear()
|
||||
return
|
||||
|
||||
self._configurations_cache.pop(tenant_id, None)
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
|
||||
"""
|
||||
@ -134,10 +114,6 @@ class ProviderManager:
|
||||
:param tenant_id:
|
||||
:return:
|
||||
"""
|
||||
cached_configurations = self._configurations_cache.get(tenant_id)
|
||||
if cached_configurations is not None:
|
||||
return cached_configurations
|

# Get all provider records of the workspace
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)

@ -297,8 +273,6 @@ class ProviderManager:

provider_configurations[str(provider_id_entity)] = provider_configuration

self._configurations_cache[tenant_id] = provider_configurations

# Return the encapsulated object
return provider_configurations

@ -445,7 +419,7 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
provider_name_to_provider_records_dict = defaultdict(list)
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
providers = session.scalars(stmt)
for provider in providers:
@ -462,7 +436,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_records_dict = defaultdict(list)
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
@ -478,7 +452,7 @@ class ProviderManager:
:return:
"""
provider_name_to_preferred_provider_type_records_dict = {}
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
preferred_provider_types = session.scalars(stmt)
provider_name_to_preferred_provider_type_records_dict = {
@ -496,7 +470,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_settings_dict = defaultdict(list)
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
@ -514,7 +488,7 @@ class ProviderManager:
:return:
"""
provider_name_to_provider_model_credentials_dict = defaultdict(list)
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
provider_model_credentials = session.scalars(stmt)
for provider_model_credential in provider_model_credentials:
@ -544,7 +518,7 @@ class ProviderManager:
return {}

provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_configs = session.scalars(stmt)
for provider_load_balancing_config in provider_load_balancing_configs:
@ -578,7 +552,7 @@ class ProviderManager:
:param provider_name: provider name
:return:
"""
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(
@ -608,7 +582,7 @@ class ProviderManager:
:param model_type: model type
:return:
"""
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderModelCredential)
.where(

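The provider hunks above repeat one pattern: open a short-lived SQLAlchemy session, run a tenant-scoped select, and iterate the scalars. A minimal sketch of that pattern, assuming an engine like db.engine; the session_factory.create_session() helper on the other side of the diff is assumed to be a thin wrapper around the same Session construction:

```python
from sqlalchemy import select
from sqlalchemy.orm import Session


def load_valid_rows(engine, model_cls, tenant_id: str):
    # expire_on_commit=False keeps loaded attributes usable after the session
    # context exits, matching the Session(db.engine, ...) variant in the hunks.
    with Session(engine, expire_on_commit=False) as session:
        stmt = select(model_cls).where(
            model_cls.tenant_id == tenant_id,
            model_cls.is_valid == True,  # noqa: E712 - the comparison is rendered into SQL
        )
        return list(session.scalars(stmt))
```

Either spelling yields an equivalent context-managed session; the difference is only where the engine and session options are configured.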
@ -139,10 +139,8 @@ class Jieba(BaseKeyword):
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
}
dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database":
if dataset_keyword_table is None:
return
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit()
else:
@ -156,8 +154,7 @@ class Jieba(BaseKeyword):
if dataset_keyword_table:
keyword_table_dict = dataset_keyword_table.keyword_table_dict
if keyword_table_dict:
data: Any = keyword_table_dict["__data__"]
return dict(data["table"])
return dict(keyword_table_dict["__data__"]["table"])
else:
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
dataset_keyword_table = DatasetKeywordTable(

@ -1,5 +1,4 @@
import re
from collections.abc import Callable
from operator import itemgetter
from typing import cast

@ -81,14 +80,12 @@ class JiebaKeywordTableHandler:

def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
top_k = cast(int | None, kwargs.pop("topK", top_k))
if top_k is None:
top_k = 20
top_k = kwargs.pop("topK", top_k)
cut = getattr(jieba, "cut", None)
if self._lcut:
tokens = self._lcut(sentence)
elif callable(cut):
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
tokens = list(cut(sentence))
else:
tokens = re.findall(r"\w+", sentence)

@ -109,9 +106,9 @@ class JiebaKeywordTableHandler:
"""Extract keywords with JIEBA tfidf."""
keywords = self._tfidf.extract_tags(
sentence=text,
topK=max_keywords_per_chunk or 10,
topK=max_keywords_per_chunk,
)
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
keywords = cast(list[str], keywords)

return set(self._expand_tokens_with_subtokens(set(keywords)))

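The extract_tags hunk above falls back to plain token frequency when jieba's TF-IDF extractor is not available. A minimal sketch of that fallback idea, assuming nothing beyond the standard library (the real handler wires this into jieba and _expand_tokens_with_subtokens):

```python
import re
from collections import Counter


def frequency_keywords(sentence: str, top_k: int = 20) -> set[str]:
    """Naive fallback: split on word characters and keep the most frequent tokens."""
    tokens = re.findall(r"\w+", sentence.lower())
    counts = Counter(token for token in tokens if len(token) > 1)
    return {token for token, _ in counts.most_common(top_k)}
```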
@ -158,7 +158,7 @@ class RetrievalService:
)

if futures:
for _ in concurrent.futures.as_completed(futures, timeout=3600):
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
@ -217,11 +217,10 @@ class RetrievalService:
"""Deduplicate documents in O(n) while preserving first-seen order.

Rules:
- If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key;
keep the doc with the highest metadata["score"] among duplicates. If a later duplicate
has no score, ignore it.
- If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content),
keeping the first occurrence.
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
- For non-dify documents (or dify without doc_id): deduplicate by content key
(provider, page_content), keeping the first occurrence.
"""
if not documents:
return documents
@ -232,10 +231,11 @@ class RetrievalService:
order: list[tuple] = []

for doc in documents:
doc_id = (doc.metadata or {}).get("doc_id")
is_dify = doc.provider == "dify"
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None

if doc_id:
key = (doc.provider or "dify", doc_id)
if is_dify and doc_id:
key = ("dify", doc_id)
if key not in chosen:
chosen[key] = doc
order.append(key)
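Both versions of the dedup hunk share one shape: build a key per document, remember first-seen order, and let a higher-scored duplicate replace the kept document. A standalone sketch of that O(n) pass, assuming a small Doc stand-in for the real Document model (the class and field defaults here are illustrative):

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    provider: str
    page_content: str
    metadata: dict = field(default_factory=dict)


def dedup_documents(documents: list[Doc]) -> list[Doc]:
    chosen: dict[tuple, Doc] = {}
    order: list[tuple] = []

    for doc in documents:
        doc_id = (doc.metadata or {}).get("doc_id")
        # doc_id-based key when present, otherwise fall back to the content key.
        key = (doc.provider, doc_id) if doc_id else (doc.provider, doc.page_content)

        if key not in chosen:
            chosen[key] = doc
            order.append(key)  # first-seen order is what the caller gets back
            continue

        # For doc_id duplicates, prefer the higher score; a later duplicate with
        # no score is ignored, and content-key duplicates always keep the first.
        if doc_id:
            new_score = (doc.metadata or {}).get("score")
            old_score = (chosen[key].metadata or {}).get("score")
            if new_score is not None and (old_score is None or new_score > old_score):
                chosen[key] = doc

    return [chosen[key] for key in order]
```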
@ -551,7 +551,6 @@ class RetrievalService:
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

for i in child_index_nodes:
assert i.index_node_id
segment_ids.append(i.segment_id)
if i.segment_id in child_chunk_map:
child_chunk_map[i.segment_id].append(i)

@ -39,58 +39,6 @@ class AbstractVectorFactory(ABC):
return index_struct_dict


class _LazyEmbeddings(Embeddings):
"""Lazy proxy that defers materializing the real embedding model.

Constructing the real embeddings (via ``ModelManager.get_model_instance``)
transitively calls ``FeatureService.get_features`` → ``BillingService``
HTTP GETs (see ``provider_manager.py``). Cleanup paths
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
at all, so deferring this until an ``embed_*`` method is actually invoked
keeps cleanup tasks resilient to transient billing-API failures and avoids
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
hiccups.

Existing callers that perform create / search operations are unaffected:
the first ``embed_*`` call materializes the underlying model and the
behavior is identical from that point on.
"""

def __init__(self, dataset: Dataset):
self._dataset = dataset
self._real: Embeddings | None = None

def _ensure(self) -> Embeddings:
if self._real is None:
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model,
)
self._real = CacheEmbedding(embedding_model)
return self._real

def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)

def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)

def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)

def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)

async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)

async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)


class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
@ -112,11 +60,7 @@ class Vector:
"original_chunk_id",
]
self._dataset = dataset
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
# never transitively trigger billing API calls during ``Vector(dataset)``
# construction. The real embedding model is materialized only when an
# ``embed_*`` method is actually invoked (i.e. create / search paths).
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()

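The removed _LazyEmbeddings class and the comment in Vector.__init__ describe the same idea: wrap an expensive dependency in a proxy and build it only on first use. A minimal, generic sketch of that lazy-proxy pattern, assuming a build() callable standing in for the embedding-model construction above:

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class Lazy(Generic[T]):
    """Defer constructing an expensive object until it is first needed."""

    def __init__(self, build: Callable[[], T]):
        self._build = build
        self._value: T | None = None

    def get(self) -> T:
        # Paths that never call get() (e.g. delete/cleanup) never pay the
        # construction cost or hit its failure modes.
        if self._value is None:
            self._value = self._build()
        return self._value
```

With this shape, cleanup-style callers can construct the wrapper cheaply, while the first embed_* call is the only place the real model (and any billing lookups behind it) gets built.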
@ -144,20 +88,8 @@ class Vector:
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
return get_vector_factory_class(vector_type)

@staticmethod
def _filter_empty_text_documents(documents: list[Document]) -> list[Document]:
filtered_documents = [document for document in documents if document.page_content.strip()]
skipped_count = len(documents) - len(filtered_documents)
if skipped_count:
logger.warning("skip %d empty documents before vector embedding", skipped_count)
return filtered_documents

def create(self, texts: list | None = None, **kwargs):
if texts:
texts = self._filter_empty_text_documents(texts)
if not texts:
return

start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
@ -215,14 +147,8 @@ class Vector:
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)

def add_texts(self, documents: list[Document], **kwargs):
documents = self._filter_empty_text_documents(documents)
if not documents:
return

if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
if not documents:
return

embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)

@ -11,7 +11,6 @@ from core.rag.models.document import AttachmentDocument, Document
from extensions.ext_database import db
from graphon.model_runtime.entities.model_entities import ModelType
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
from models.enums import SegmentType


class DatasetDocumentStore:
@ -128,7 +127,6 @@ class DatasetDocumentStore:
if save_child:
if doc.children:
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -139,7 +137,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type=SegmentType.AUTOMATIC,
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
@ -165,7 +163,6 @@ class DatasetDocumentStore:
)
# add new child chunks
for position, child in enumerate(doc.children, start=1):
assert self._document_id
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
@ -176,7 +173,7 @@ class DatasetDocumentStore:
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type=SegmentType.AUTOMATIC,
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)

@ -94,7 +94,6 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
upload_file = extract_setting.upload_file
if not file_path:
@ -105,7 +104,6 @@ class ExtractProcessor:
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
assert upload_file is not None, "upload_file is required"
etl_type = dify_config.ETL_TYPE
extractor: BaseExtractor | None = None
if etl_type == "Unstructured":

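The extraction path above downloads the uploaded object into a temporary directory and then picks an extractor from the file suffix and the configured ETL type. A minimal sketch of that flow, assuming a download(key, path) callable and an extractor registry keyed by suffix (both stand-ins for the real storage and extractor classes):

```python
import tempfile
from collections.abc import Callable
from pathlib import Path


def extract_text(
    key: str,
    filename: str,
    download: Callable[[str, str], None],
    extractors: dict[str, Callable[[Path], str]],
) -> str:
    """Download an uploaded file to a temp dir, then dispatch on its suffix."""
    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = Path(temp_dir) / filename
        download(key, str(local_path))

        suffix = local_path.suffix.lower()  # e.g. ".pdf", ".md", ".html"
        extractor = extractors.get(suffix)
        if extractor is None:
            raise ValueError(f"no extractor registered for {suffix!r}")
        return extractor(local_path)
```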
@ -28,10 +28,10 @@ class FunctionCallMultiDatasetRouter:
SystemPromptMessage(content="You are a helpful AI assistant."),
UserPromptMessage(content=query),
]
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
result: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
tools=dataset_tools,
stream=False, # pyright: ignore[reportArgumentType]
stream=False,
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
)
usage = result.usage or LLMUsage.empty_usage()

@ -4,7 +4,7 @@ from __future__ import annotations

import codecs
import re
from collections.abc import Set as AbstractSet
from collections.abc import Collection
from typing import Any, Literal

from core.model_manager import ModelInstance
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
cls: type[T],
embedding_model_instance: ModelInstance | None,
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
allowed_special: Literal["all"] | set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
) -> T:
def _token_encoder(texts: list[str]) -> list[int]:
@ -40,7 +40,6 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):

return [len(text) for text in texts]

_ = _token_encoder # kept for future token-length wiring
return cls(length_function=_character_encoder, **kwargs)

@ -4,8 +4,7 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Set as AbstractSet
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import Any, Literal

@ -188,8 +187,8 @@ class TokenTextSplitter(TextSplitter):
self,
encoding_name: str = "gpt2",
model_name: str | None = None,
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
allowed_special: Literal["all"] | Set[str] = set(),
disallowed_special: Literal["all"] | Collection[str] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
@ -208,8 +207,8 @@ class TokenTextSplitter(TextSplitter):
else:
enc = tiktoken.get_encoding(encoding_name)
self._tokenizer = enc
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special

def split_text(self, text: str) -> list[str]:
def _encode(_text: str) -> list[int]:

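The TokenTextSplitter hunks above pass allowed_special / disallowed_special straight through to tiktoken when encoding, so chunk lengths are measured in model tokens rather than characters. A small sketch of that measurement, assuming the tiktoken package is installed (the splitter wiring itself is omitted):

```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")


def token_length(text: str) -> int:
    # Mirrors the splitter's _encode helper: special-token handling is controlled
    # by allowed_special / disallowed_special, here left at conservative defaults.
    return len(enc.encode(text, allowed_special=set(), disallowed_special="all"))


print(token_length("split me into token-sized chunks"))
```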
@ -1078,13 +1078,6 @@ class ToolManager:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
if variable_pool:
config = tool_configurations.get(parameter.name, {})

selector_value = cls._extract_runtime_selector_value(parameter, config)
if selector_value is not None:
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
runtime_parameters[parameter.name] = selector_value
continue

if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
@ -1112,39 +1105,5 @@ class ToolManager:
runtime_parameters[parameter.name] = value
return runtime_parameters

@classmethod
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
if parameter.type not in {
ToolParameter.ToolParameterType.MODEL_SELECTOR,
ToolParameter.ToolParameterType.APP_SELECTOR,
}:
return None
if not isinstance(config, dict):
return None

input_value = config.get("value")
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))

if cls._is_selector_value(parameter, config):
selector_value = dict(config)
selector_value.pop("type", None)
selector_value.pop("value", None)
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))

return None

@classmethod
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
)
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
return isinstance(value.get("app_id"), str)
return False


ToolManager.load_hardcoded_providers_cache()

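The _is_selector_value check above distinguishes the two structured parameter payloads by shape alone. A minimal standalone version of that shape test; the dictionary keys are the same ones the hunk checks, while the example values are purely illustrative:

```python
from typing import Any


def looks_like_model_selector(value: dict[str, Any]) -> bool:
    return (
        isinstance(value.get("provider"), str)
        and isinstance(value.get("model"), str)
        and isinstance(value.get("model_type"), str)
    )


def looks_like_app_selector(value: dict[str, Any]) -> bool:
    return isinstance(value.get("app_id"), str)


# Example payloads of each kind (values are hypothetical):
model_selector = {"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}
app_selector = {"app_id": "a1b2c3"}
assert looks_like_model_selector(model_selector) and looks_like_app_selector(app_selector)
```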
@ -14,23 +14,23 @@ from configs import dify_config
logger = logging.getLogger(__name__)


class EncryptionError(Exception):
"""Encryption/decryption specific error"""
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""

pass


class SystemEncrypter:
class SystemOAuthEncrypter:
"""
A simple parameters encrypter using AES-CBC encryption.
A simple OAuth parameters encrypter using AES-CBC encryption.

This class provides methods to encrypt and decrypt parameters
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""

def __init__(self, secret_key: str | None = None):
"""
Initialize the encrypter.
Initialize the OAuth encrypter.

Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
@ -43,19 +43,19 @@ class SystemEncrypter:
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()

def encrypt_params(self, params: Mapping[str, Any]) -> str:
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt parameters.
Encrypt OAuth parameters.

Args:
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}

Returns:
Base64-encoded encrypted string

Raises:
EncryptionError: If encryption fails
ValueError: If params is invalid
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""

try:
@ -66,7 +66,7 @@ class SystemEncrypter:
cipher = AES.new(self.key, AES.MODE_CBC, iv)

# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)

# Combine IV and encrypted data
@ -76,20 +76,20 @@ class SystemEncrypter:
return base64.b64encode(combined).decode()

except Exception as e:
raise EncryptionError(f"Encryption failed: {str(e)}") from e
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e

def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt parameters.
Decrypt OAuth parameters.

Args:
encrypted_data: Base64-encoded encrypted string

Returns:
Decrypted parameters dictionary
Decrypted OAuth parameters dictionary

Raises:
EncryptionError: If decryption fails
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
@ -118,70 +118,70 @@ class SystemEncrypter:
unpadded_data = unpad(decrypted_data, AES.block_size)

# Parse JSON
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)

if not isinstance(params, dict):
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")

return params
return oauth_params

except Exception as e:
raise EncryptionError(f"Decryption failed: {str(e)}") from e
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e


# Factory function for creating encrypter instances
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
"""
Create an encrypter instance.
Create an OAuth encrypter instance.

Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY

Returns:
SystemEncrypter instance
SystemOAuthEncrypter instance
"""
return SystemEncrypter(secret_key=secret_key)
return SystemOAuthEncrypter(secret_key=secret_key)


# Global encrypter instance (for backward compatibility)
_encrypter: SystemEncrypter | None = None
_oauth_encrypter: SystemOAuthEncrypter | None = None


def get_system_encrypter() -> SystemEncrypter:
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global encrypter instance.
Get the global OAuth encrypter instance.

Returns:
SystemEncrypter instance
SystemOAuthEncrypter instance
"""
global _encrypter
if _encrypter is None:
_encrypter = SystemEncrypter()
return _encrypter
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter


# Convenience functions for backward compatibility
def encrypt_system_params(params: Mapping[str, Any]) -> str:
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt parameters using the global encrypter.
Encrypt OAuth parameters using the global encrypter.

Args:
params: Parameters dictionary
oauth_params: OAuth parameters dictionary

Returns:
Base64-encoded encrypted string
"""
return get_system_encrypter().encrypt_params(params)
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)


def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt parameters using the global encrypter.
Decrypt OAuth parameters using the global encrypter.

Args:
encrypted_data: Base64-encoded encrypted string

Returns:
Decrypted parameters dictionary
Decrypted OAuth parameters dictionary
"""
return get_system_encrypter().decrypt_params(encrypted_data)
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)

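The encrypter above serializes the parameter dict to JSON, pads and AES-CBC-encrypts it with a SHA-256-derived key, prepends the IV, and base64-encodes the result; decryption reverses each step. A hedged round-trip sketch using the module-level helpers named in the new code (the import path is an assumption, and SECRET_KEY must be configured so the derived key is stable):

```python
# Round-trip through the convenience wrappers shown in the hunk above.
from core.tools.utils.system_oauth_encryption import (  # assumed module path
    decrypt_system_oauth_params,
    encrypt_system_oauth_params,
)

params = {"client_id": "xxx", "client_secret": "xxx"}
token = encrypt_system_oauth_params(params)      # base64(IV + AES-CBC ciphertext)
assert decrypt_system_oauth_params(token) == params
```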
@ -105,7 +105,7 @@ class Article:


def extract_using_readabilipy(html: str):
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
article = Article(
title=json_article.get("title") or "",
author=json_article.get("byline") or "",

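extract_using_readabilipy above builds the Article from whatever simple_json_from_html_string returns; the use_readability flag chooses between the Node-based Readability.js parser (True) and readabilipy's pure-Python parser (False). A small usage sketch, assuming the readabilipy package is installed (field availability can vary by page):

```python
from readabilipy import simple_json_from_html_string

html = "<html><head><title>Hello</title></head><body><p>Body text</p></body></html>"
json_article = simple_json_from_html_string(html, use_readability=False)

# Typical keys include "title", "byline" and "content"; missing fields come
# back as None, hence the `or ""` guards in the hunk above.
print(json_article.get("title"), json_article.get("byline"))
```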
@ -272,14 +272,6 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
normalized_tool_configurations[name] = value
continue

selector_value = _extract_selector_configuration(value)
if selector_value is not None:
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
found_legacy_tool_inputs = True
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
continue

input_type = value.get("type")
input_value = value.get("value")
if input_type not in {"mixed", "variable", "constant"}:
@ -318,28 +310,6 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
return None


def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
input_value = value.get("value")
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
return dict(input_value)

if _is_selector_configuration(value):
selector_value = dict(value)
selector_value.pop("type", None)
selector_value.pop("value", None)
return selector_value

return None


def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
return (
isinstance(value.get("provider"), str)
and isinstance(value.get("model"), str)
and isinstance(value.get("model_type"), str)
) or isinstance(value.get("app_id"), str)

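The adapter above promotes legacy selector-style tool_configurations entries into tool_parameters so the structured dictionary survives graph validation. A small data-shape illustration of that move; only the {"type": "constant", "value": ...} wrapper and the selector keys come from the hunk, the concrete values are hypothetical:

```python
# A legacy selector-style entry found under tool_configurations ...
selector_value = {"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}

# ... is re-registered as a constant tool parameter so graph validation keeps
# the structured dictionary intact instead of flattening it to a primitive.
normalized_tool_parameters: dict[str, dict] = {}
normalized_tool_parameters.setdefault("model", {"type": "constant", "value": selector_value})
print(normalized_tool_parameters)
```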
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
normalized = dict(recipients)

@ -12,7 +12,7 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
from core.workflow.human_input_policy import get_preferred_form_token
from extensions.ext_database import db
from models.human_input import HumanInputFormRecipient, RecipientType

@ -21,7 +21,6 @@ def load_form_tokens_by_form_id(
form_ids: Sequence[str],
*,
session: Session | None = None,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
"""Load the preferred access token for each human input form."""
unique_form_ids = list(dict.fromkeys(form_ids))
@ -29,43 +28,23 @@
return {}

if session is not None:
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
return _load_form_tokens_by_form_id(session, unique_form_ids)

with Session(bind=db.engine, expire_on_commit=False) as new_session:
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
return _load_form_tokens_by_form_id(new_session, unique_form_ids)


def _load_form_tokens_by_form_id(
session: Session,
form_ids: Sequence[str],
*,
surface: HumanInputSurface | None = None,
) -> dict[str, str]:
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
for recipient in session.scalars(stmt):
if not recipient.access_token:
continue
recipients_by_form_id.setdefault(recipient.form_id, []).append(
(recipient.recipient_type, recipient.access_token)
)
recipients_by_form_id.setdefault(recipient.form_id, []).append((recipient.recipient_type, recipient.access_token))

tokens_by_form_id: dict[str, str] = {}
for form_id, recipients in recipients_by_form_id.items():
token = _get_surface_form_token(recipients, surface=surface)
token = get_preferred_form_token(recipients)
if token is not None:
tokens_by_form_id[form_id] = token
return tokens_by_form_id


def _get_surface_form_token(
recipients: Sequence[tuple[RecipientType, str]],
*,
surface: HumanInputSurface | None,
) -> str | None:
if surface == HumanInputSurface.SERVICE_API:
for recipient_type, token in recipients:
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
return token

return get_preferred_form_token(recipients)

@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any

from graphon.entities.pause_reason import PauseReasonType
from models.human_input import RecipientType


@ -62,7 +61,7 @@ def enrich_human_input_pause_reasons(
enriched: list[dict[str, Any]] = []
for reason in reasons:
updated = dict(reason)
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
if updated.get("type") == "human_input_required":
form_id = updated.get("form_id")
if isinstance(form_id, str):
updated["form_token"] = form_tokens_by_form_id.get(form_id)

Some files were not shown because too many files have changed in this diff.