mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 02:16:51 +08:00
Compare commits
179 Commits
p363-revie
...
feat/stora
| Author | SHA1 | Date | |
|---|---|---|---|
| a14c980e32 | |||
| 1b0d4637b3 | |||
| 936a09c704 | |||
| 5cc62fd1c9 | |||
| 7bc19d8251 | |||
| e845475408 | |||
| 9a8aa6a0c3 | |||
| 76a7f5f4b9 | |||
| 2ff50514c8 | |||
| 7901ac9a97 | |||
| ecd830083a | |||
| 203b3a9499 | |||
| 9331024d91 | |||
| c6a5de3c18 | |||
| cd3327013a | |||
| cd66559ebf | |||
| 8b77ec7f31 | |||
| bb3de5dd32 | |||
| 1e2d309122 | |||
| a24ec60e51 | |||
| 8fd616d27f | |||
| e5bdc40dce | |||
| 376c43e5ac | |||
| 3ebb449d25 | |||
| 5297ac76ec | |||
| bbed1d4a7c | |||
| c804dbed8c | |||
| 00bf3f83f2 | |||
| 7e6745e105 | |||
| d648ce6888 | |||
| f3c3534e33 | |||
| 8967ff34b3 | |||
| 57539792c1 | |||
| 03e227f8f1 | |||
| 506e1a8bc7 | |||
| f8873ec07b | |||
| b2dacf0718 | |||
| 70eb98d6c5 | |||
| b83f296634 | |||
| 5c68f12bb8 | |||
| 4df7c00859 | |||
| 995c43f3dd | |||
| c0431ec843 | |||
| a0af10abc8 | |||
| 8e2b8168be | |||
| 1f29565673 | |||
| 90fe54ca9e | |||
| b43ebf539d | |||
| 853b859032 | |||
| 8f3e42e9c2 | |||
| 1359c03216 | |||
| 4b7dc17546 | |||
| 81090effe2 | |||
| d92c336394 | |||
| cd9daef564 | |||
| 2876839d7e | |||
| 7ba408eebe | |||
| 3708e3eef1 | |||
| ff5c2c57a1 | |||
| 955c25589d | |||
| 54bde0bdf6 | |||
| 87add9a4f3 | |||
| 574d5865f4 | |||
| 458fab1c48 | |||
| 88196c186e | |||
| dcf21a6a84 | |||
| 91f92c7083 | |||
| 0ca339103f | |||
| 5cf741895f | |||
| 11c52e90f6 | |||
| f01e099729 | |||
| 195ff4711d | |||
| fe2f7a8920 | |||
| 3b1458c08f | |||
| 9f47317032 | |||
| e751ec323e | |||
| f1d72eb5d2 | |||
| 44242d03b4 | |||
| ed7ea68f7d | |||
| afbc30c9ed | |||
| 0e55dcb297 | |||
| 25973c7d77 | |||
| 73ecdd5494 | |||
| 6fafeec415 | |||
| d23cefe005 | |||
| 16d408d908 | |||
| 0536549f73 | |||
| d0956039e7 | |||
| 38eb04dc98 | |||
| d2e1da269c | |||
| 1d3498f659 | |||
| b8dea56198 | |||
| e2becd6746 | |||
| 28a26f2d59 | |||
| 8c7393ef46 | |||
| 5a7a955210 | |||
| 3e4849d765 | |||
| 0c280ef708 | |||
| 282561a861 | |||
| cbb4cc5d76 | |||
| 2d6babeeb4 | |||
| 1065a4840a | |||
| b6aa5a7d69 | |||
| 949f930698 | |||
| 65a08ed7ab | |||
| cc4d6db7c8 | |||
| 6b5d6dacb2 | |||
| 89bf75eba9 | |||
| 3a28868a6c | |||
| 4036515abe | |||
| 6c089cab66 | |||
| 818a71d637 | |||
| 3db107edc9 | |||
| 2677d90860 | |||
| 859756c4f6 | |||
| 295fb6e74a | |||
| 2326fb7a83 | |||
| 2d6eaf69f9 | |||
| 3e826c0000 | |||
| b1b977e284 | |||
| 23648141c9 | |||
| d6dee43c09 | |||
| 7efc887e32 | |||
| 8b346e69d9 | |||
| ef7ff3356d | |||
| 7b5c0b5045 | |||
| f00512dd5d | |||
| e6ef774fd5 | |||
| ce50c6cf1c | |||
| 7002512106 | |||
| c3aebb8403 | |||
| 0baefa6163 | |||
| 7bcedcbaab | |||
| 791fc5819d | |||
| 2d09c4788d | |||
| 9bd5c2f8ec | |||
| 5e336c47fd | |||
| be4c828214 | |||
| ec450eb7f9 | |||
| 48e13f65dc | |||
| 38fc2a6574 | |||
| ed8d3f3e8d | |||
| 0c8dec3315 | |||
| 38e831c1b3 | |||
| 1c5d62d98a | |||
| 6b4736bf78 | |||
| c9503fd818 | |||
| 91a1df96cb | |||
| 5b2c5da945 | |||
| b59ecea346 | |||
| 61c0948136 | |||
| f746c7bdf2 | |||
| 2a3deee385 | |||
| 4b6803ba06 | |||
| 4c908c8f39 | |||
| afec528f51 | |||
| 491061b8f4 | |||
| 8b1533438f | |||
| ba924fc97b | |||
| 712e522220 | |||
| 33eebe8cfc | |||
| 2e1b11bdb2 | |||
| d65a6b4810 | |||
| 44a91e344c | |||
| 0fec9af6a6 | |||
| 5e5113e08e | |||
| 73f9a9e7d6 | |||
| 48d23cd744 | |||
| 0b60bf6ef0 | |||
| 3b24d8d2d1 | |||
| 051ba99cd2 | |||
| dc83e8aa09 | |||
| 77f8f2babb | |||
| 77d6c108e7 | |||
| c2a5962023 | |||
| d583b1b835 | |||
| da00de6688 | |||
| 04472cbd58 | |||
| 62d25cbffb |
@ -367,7 +367,7 @@ For each extraction:
|
||||
┌────────────────────────────────────────┐
|
||||
│ 1. Extract code │
|
||||
│ 2. Run: pnpm lint:fix │
|
||||
│ 3. Run: pnpm type-check:tsgo │
|
||||
│ 3. Run: pnpm type-check │
|
||||
│ 4. Run: pnpm test │
|
||||
│ 5. Test functionality manually │
|
||||
│ 6. PASS? → Next extraction │
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
---
|
||||
name: frontend-query-mutation
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions()/mutationOptions() directly or extract a helper or use-* hook, configuring oRPC experimental_defaults/default options, handling conditional queries, cache updates/invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers.
|
||||
---
|
||||
|
||||
# Frontend Query & Mutation
|
||||
@ -9,22 +9,24 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
|
||||
- Keep contract as the single source of truth in `web/contract/*`.
|
||||
- Prefer contract-shaped `queryOptions()` and `mutationOptions()`.
|
||||
- Keep invalidation and mutation flow knowledge in the service layer.
|
||||
- Keep default cache behavior with `consoleQuery`/`marketplaceQuery` setup, and keep business orchestration in feature vertical hooks when direct contract calls are not enough.
|
||||
- Treat `web/service/use-*` query or mutation wrappers as legacy migration targets, not the preferred destination.
|
||||
- Keep abstractions minimal to preserve TypeScript inference.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Identify the change surface.
|
||||
- Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape.
|
||||
- Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations.
|
||||
- Read `references/runtime-rules.md` for conditional queries, default options, cache updates/invalidation, error handling, and legacy migrations.
|
||||
- Read both references when a task spans contract shape and runtime behavior.
|
||||
2. Implement the smallest abstraction that fits the task.
|
||||
- Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site.
|
||||
- Extract a small shared query helper only when multiple call sites share the same extra options.
|
||||
- Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior.
|
||||
- Create or keep feature hooks only for real orchestration or shared domain behavior.
|
||||
- When touching thin `web/service/use-*` wrappers, migrate them away when feasible.
|
||||
3. Preserve Dify conventions.
|
||||
- Keep contract inputs in `{ params, query?, body? }` shape.
|
||||
- Bind invalidation in the service-layer mutation definition.
|
||||
- Bind default cache updates/invalidation in `createTanstackQueryUtils(...experimental_defaults...)`; use feature hooks only for workflows that cannot be expressed as default operation behavior.
|
||||
- Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required.
|
||||
|
||||
## Files Commonly Touched
|
||||
@ -33,7 +35,7 @@ description: Guide for implementing Dify frontend query and mutation patterns wi
|
||||
- `web/contract/marketplace.ts`
|
||||
- `web/contract/router.ts`
|
||||
- `web/service/client.ts`
|
||||
- `web/service/use-*.ts`
|
||||
- legacy `web/service/use-*.ts` files when migrating wrappers away
|
||||
- component and hook call sites using `consoleQuery` or `marketplaceQuery`
|
||||
|
||||
## References
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
interface:
|
||||
display_name: "Frontend Query & Mutation"
|
||||
short_description: "Dify TanStack Query and oRPC patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations."
|
||||
short_description: "Dify TanStack Query, oRPC, and default option patterns"
|
||||
default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, oRPC default options, conditional queries, cache updates/invalidation, or legacy query/mutation migrations."
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
- Core workflow
|
||||
- Query usage decision rule
|
||||
- Mutation usage decision rule
|
||||
- Thin hook decision rule
|
||||
- Anti-patterns
|
||||
- Contract rules
|
||||
- Type export
|
||||
@ -55,9 +56,13 @@ const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({
|
||||
|
||||
1. Default to direct `*.queryOptions(...)` usage at the call site.
|
||||
2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook.
|
||||
3. Create `web/service/use-{domain}.ts` only for orchestration.
|
||||
3. Create or keep feature hooks only for orchestration.
|
||||
- Combine multiple queries or mutations.
|
||||
- Share domain-level derived state or invalidation helpers.
|
||||
- Prefer `web/features/{domain}/hooks/*` for feature-owned workflows.
|
||||
4. Treat `web/service/use-{domain}.ts` as legacy.
|
||||
- Do not create new thin service wrappers for oRPC contracts.
|
||||
- When touching existing wrappers, inline direct `consoleQuery` or `marketplaceQuery` consumption when the wrapper is only a passthrough.
|
||||
|
||||
```typescript
|
||||
const invoicesBaseQueryOptions = () =>
|
||||
@ -74,11 +79,37 @@ const invoiceQuery = useQuery({
|
||||
1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`.
|
||||
2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic.
|
||||
|
||||
```typescript
|
||||
const createTagMutation = useMutation(consoleQuery.tags.create.mutationOptions())
|
||||
```
|
||||
|
||||
## Thin Hook Decision Rule
|
||||
|
||||
Remove thin hooks when they only rename a single oRPC query or mutation helper.
|
||||
Keep hooks when they orchestrate business behavior across multiple operations, own local workflow state, or normalize a feature-specific API.
|
||||
Prefer feature vertical hooks for kept orchestration. Do not move new contract-first wrappers into `web/service/use-*`.
|
||||
|
||||
Use:
|
||||
|
||||
```typescript
|
||||
const deleteTagMutation = useMutation(consoleQuery.tags.delete.mutationOptions())
|
||||
```
|
||||
|
||||
Keep:
|
||||
|
||||
```typescript
|
||||
const applyTagBindingsMutation = useApplyTagBindingsMutation()
|
||||
```
|
||||
|
||||
`useApplyTagBindingsMutation` is acceptable because it coordinates bind and unbind requests, computes deltas, and exposes a feature-level workflow rather than a single endpoint passthrough.
|
||||
|
||||
## Anti-Patterns
|
||||
|
||||
- Do not wrap `useQuery` with `options?: Partial<UseQueryOptions>`.
|
||||
- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case.
|
||||
- Do not create thin `use-*` passthrough hooks for a single endpoint.
|
||||
- Do not create business-layer helpers whose only purpose is to call `consoleQuery.xxx.mutationOptions()` or `queryOptions()`.
|
||||
- Do not introduce new `web/service/use-*` files for oRPC contract passthroughs.
|
||||
- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection.
|
||||
|
||||
## Contract Rules
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
## Table of Contents
|
||||
|
||||
- Conditional queries
|
||||
- oRPC default options
|
||||
- Cache invalidation
|
||||
- Key API guide
|
||||
- `mutate` vs `mutateAsync`
|
||||
@ -35,9 +36,50 @@ function useBadAccessMode(appId: string | undefined) {
|
||||
}
|
||||
```
|
||||
|
||||
## oRPC Default Options
|
||||
|
||||
Use `experimental_defaults` in `createTanstackQueryUtils` when a contract operation should always carry shared TanStack Query behavior, such as default stale time, mutation cache writes, or invalidation.
|
||||
|
||||
Place defaults at the query utility creation point in `web/service/client.ts`:
|
||||
|
||||
```typescript
|
||||
export const consoleQuery = createTanstackQueryUtils(consoleClient, {
|
||||
path: ['console'],
|
||||
experimental_defaults: {
|
||||
tags: {
|
||||
create: {
|
||||
mutationOptions: {
|
||||
onSuccess: (tag, _variables, _result, context) => {
|
||||
context.client.setQueryData(
|
||||
consoleQuery.tags.list.queryKey({
|
||||
input: {
|
||||
query: {
|
||||
type: tag.type,
|
||||
},
|
||||
},
|
||||
}),
|
||||
(oldTags: Tag[] | undefined) => oldTags ? [tag, ...oldTags] : oldTags,
|
||||
)
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
Rules:
|
||||
|
||||
- Keep defaults inline in the `consoleQuery` or `marketplaceQuery` initialization when they need sibling oRPC key builders.
|
||||
- Do not create a wrapper function solely to host `createTanstackQueryUtils`.
|
||||
- Do not split defaults into a vertical feature file if that forces handwritten operation paths such as `generateOperationKey(['console', ...])`.
|
||||
- Keep feature-level orchestration in the feature vertical; keep query utility lifecycle defaults with the query utility.
|
||||
- Prefer call-site callbacks for UI feedback only; shared cache behavior belongs in oRPC defaults when it is tied to a contract operation.
|
||||
|
||||
## Cache Invalidation
|
||||
|
||||
Bind invalidation in the service-layer mutation definition.
|
||||
Bind shared invalidation in oRPC defaults when it is tied to a contract operation.
|
||||
Use feature vertical hooks only for multi-operation workflows or domain orchestration that cannot live in a single operation default.
|
||||
Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate.
|
||||
|
||||
Use:
|
||||
@ -49,7 +91,7 @@ Use:
|
||||
Do not use deprecated `useInvalid` from `use-base.ts`.
|
||||
|
||||
```typescript
|
||||
// Service layer owns cache invalidation.
|
||||
// Feature orchestration owns cache invalidation only when defaults are not enough.
|
||||
export const useUpdateAccessMode = () => {
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
@ -124,7 +166,7 @@ When touching old code, migrate it toward these rules:
|
||||
|
||||
| Old pattern | New pattern |
|
||||
|---|---|
|
||||
| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` |
|
||||
| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition |
|
||||
| `useInvalid(key)` in service wrappers | oRPC defaults, or a feature vertical hook for real orchestration |
|
||||
| component-triggered invalidation after mutation | move invalidation into oRPC defaults or a feature vertical hook |
|
||||
| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` |
|
||||
| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` |
|
||||
|
||||
@ -127,7 +127,7 @@ For the current file being tested:
|
||||
- [ ] Run full directory test: `pnpm test path/to/directory/`
|
||||
- [ ] Check coverage report: `pnpm test:coverage`
|
||||
- [ ] Run `pnpm lint:fix` on all test files
|
||||
- [ ] Run `pnpm type-check:tsgo`
|
||||
- [ ] Run `pnpm type-check`
|
||||
|
||||
## Common Issues to Watch
|
||||
|
||||
|
||||
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@ -6,6 +6,9 @@
|
||||
|
||||
* @crazywoola @laipz8200 @Yeuoly
|
||||
|
||||
# ESLint suppression file is maintained by autofix.ci pruning.
|
||||
/eslint-suppressions.json
|
||||
|
||||
# CODEOWNERS file
|
||||
/.github/CODEOWNERS @laipz8200 @crazywoola
|
||||
|
||||
|
||||
2
.github/actions/setup-web/action.yml
vendored
2
.github/actions/setup-web/action.yml
vendored
@ -4,7 +4,7 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
|
||||
1
.github/labeler.yml
vendored
1
.github/labeler.yml
vendored
@ -6,5 +6,4 @@ web:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
api-unit:
|
||||
name: API Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
COVERAGE_FILE: coverage-unit
|
||||
defaults:
|
||||
@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
api-integration:
|
||||
name: API Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
env:
|
||||
COVERAGE_FILE: coverage-integration
|
||||
STORAGE_TYPE: opendal
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
|
||||
api-coverage:
|
||||
name: API Coverage
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
needs:
|
||||
- api-unit
|
||||
- api-integration
|
||||
|
||||
5
.github/workflows/autofix.yml
vendored
5
.github/workflows/autofix.yml
vendored
@ -13,7 +13,7 @@ permissions:
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
@ -43,7 +43,6 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
@ -114,7 +113,7 @@ jobs:
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Setup web environment
|
||||
if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true'
|
||||
if: github.event_name != 'merge_group'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: ESLint autofix
|
||||
|
||||
46
.github/workflows/build-push.yml
vendored
46
.github/workflows/build-push.yml
vendored
@ -26,6 +26,9 @@ jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
@ -35,28 +38,28 @@ jobs:
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@ -70,8 +73,8 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
@ -81,16 +84,15 @@ jobs:
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=gha,scope=${{ matrix.service_name }}
|
||||
cache-to: type=gha,mode=max,scope=${{ matrix.service_name }}
|
||||
|
||||
- name: Export digest
|
||||
env:
|
||||
@ -108,9 +110,33 @@ jobs:
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
fork-build-validate:
|
||||
if: github.repository != 'langgenius/dify'
|
||||
runs-on: ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "validate-api-amd64"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "validate-web-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Validate Docker image
|
||||
uses: docker/build-push-action@bcafcacb16a39f128d818304e6c9c0c18556b85f # v7.1.0
|
||||
with:
|
||||
push: false
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: linux/amd64
|
||||
|
||||
create-manifest:
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
26
.github/workflows/db-migration-test.yml
vendored
26
.github/workflows/db-migration-test.yml
vendored
@ -9,7 +9,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
db-migration-test-postgres:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -59,7 +59,7 @@ jobs:
|
||||
run: uv run --directory api flask upgrade-db
|
||||
|
||||
db-migration-test-mysql:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -110,6 +110,28 @@ jobs:
|
||||
sed -i 's/DB_PORT=5432/DB_PORT=3306/' .env
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=root/' .env
|
||||
|
||||
# hoverkraft-tech/compose-action@v2.6.0 only waits for `docker compose up -d`
|
||||
# to return (container processes started); it does not wait on healthcheck
|
||||
# status. mysql:8.0's first-time init takes 15-30s, so without an explicit
|
||||
# wait the migration runs while InnoDB is still initialising and gets
|
||||
# killed with "Lost connection during query". Poll a real SELECT until it
|
||||
# succeeds.
|
||||
- name: Wait for MySQL to accept queries
|
||||
run: |
|
||||
set +e
|
||||
for i in $(seq 1 60); do
|
||||
if docker run --rm --network host mysql:8.0 \
|
||||
mysql -h 127.0.0.1 -P 3306 -uroot -pdifyai123456 \
|
||||
-e 'SELECT 1' >/dev/null 2>&1; then
|
||||
echo "MySQL ready after ${i}s"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "MySQL not ready after 60s; dumping container logs:"
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile mysql logs --tail=200 db_mysql
|
||||
exit 1
|
||||
|
||||
- name: Run DB Migration
|
||||
env:
|
||||
DEBUG: true
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
|
||||
2
.github/workflows/deploy-enterprise.yml
vendored
2
.github/workflows/deploy-enterprise.yml
vendored
@ -13,7 +13,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/enterprise'
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
|
||||
43
.github/workflows/docker-build.yml
vendored
43
.github/workflows/docker-build.yml
vendored
@ -14,28 +14,59 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
if: github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up Depot CLI
|
||||
uses: depot/setup-action@15c09a5f77a0840ad4bce955686522a257853461 # v1.7.1
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: depot/build-push-action@5f3b3c2e5a00f0093de47f657aeaefcedff27d18 # v1.17.0
|
||||
with:
|
||||
project: ${{ vars.DEPOT_PROJECT_ID }}
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
|
||||
build-docker-fork:
|
||||
if: github.event.pull_request.head.repo.full_name != github.repository
|
||||
runs-on: ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
@ -48,6 +79,4 @@ jobs:
|
||||
push: false
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
platforms: linux/amd64
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@ -7,7 +7,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
with:
|
||||
|
||||
26
.github/workflows/main-ci.yml
vendored
26
.github/workflows/main-ci.yml
vendored
@ -23,7 +23,7 @@ concurrency:
|
||||
jobs:
|
||||
pre_job:
|
||||
name: Skip Duplicate Checks
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }}
|
||||
steps:
|
||||
@ -39,7 +39,7 @@ jobs:
|
||||
name: Check Changed Files
|
||||
needs: pre_job
|
||||
if: needs.pre_job.outputs.should_skip != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
outputs:
|
||||
api-changed: ${{ steps.changes.outputs.api }}
|
||||
e2e-changed: ${{ steps.changes.outputs.e2e }}
|
||||
@ -69,7 +69,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@ -83,7 +82,6 @@ jobs:
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
@ -141,7 +139,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped API tests
|
||||
run: echo "No API-related changes detected; skipping API tests."
|
||||
@ -154,7 +152,7 @@ jobs:
|
||||
- check-changes
|
||||
- api-tests-run
|
||||
- api-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize API Tests status
|
||||
env:
|
||||
@ -201,7 +199,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped web tests
|
||||
run: echo "No web-related changes detected; skipping web tests."
|
||||
@ -214,7 +212,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-tests-run
|
||||
- web-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize Web Tests status
|
||||
env:
|
||||
@ -260,7 +258,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped web full-stack e2e
|
||||
run: echo "No E2E-related changes detected; skipping web full-stack E2E."
|
||||
@ -273,7 +271,7 @@ jobs:
|
||||
- check-changes
|
||||
- web-e2e-run
|
||||
- web-e2e-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize Web Full-Stack E2E status
|
||||
env:
|
||||
@ -325,7 +323,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped VDB tests
|
||||
run: echo "No VDB-related changes detected; skipping VDB tests."
|
||||
@ -338,7 +336,7 @@ jobs:
|
||||
- check-changes
|
||||
- vdb-tests-run
|
||||
- vdb-tests-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize VDB Tests status
|
||||
env:
|
||||
@ -384,7 +382,7 @@ jobs:
|
||||
- pre_job
|
||||
- check-changes
|
||||
if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Report skipped DB migration tests
|
||||
run: echo "No migration-related changes detected; skipping DB migration tests."
|
||||
@ -397,7 +395,7 @@ jobs:
|
||||
- check-changes
|
||||
- db-migration-test-run
|
||||
- db-migration-test-skip
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Finalize DB Migration Test status
|
||||
env:
|
||||
|
||||
2
.github/workflows/pyrefly-diff-comment.yml
vendored
2
.github/workflows/pyrefly-diff-comment.yml
vendored
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with pyrefly diff
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-diff:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
@ -12,7 +12,7 @@ permissions: {}
|
||||
jobs:
|
||||
comment:
|
||||
name: Comment PR with type coverage
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
2
.github/workflows/pyrefly-type-coverage.yml
vendored
2
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -10,7 +10,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
pyrefly-type-coverage:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
name: Validate PR title
|
||||
permissions:
|
||||
pull-requests: read
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
steps:
|
||||
- name: Complete merge group check
|
||||
if: github.event_name == 'merge_group'
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -12,7 +12,7 @@ on:
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
9
.github/workflows/style.yml
vendored
9
.github/workflows/style.yml
vendored
@ -15,7 +15,7 @@ permissions:
|
||||
jobs:
|
||||
python-style:
|
||||
name: Python Style
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
@ -57,7 +57,7 @@ jobs:
|
||||
|
||||
web-style:
|
||||
name: Web Style
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
@ -83,7 +83,6 @@ jobs:
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
@ -110,6 +109,8 @@ jobs:
|
||||
- name: Web tsslint
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
env:
|
||||
NODE_OPTIONS: --max-old-space-size=4096
|
||||
run: vp run lint:tss
|
||||
|
||||
- name: Web type check
|
||||
@ -131,7 +132,7 @@ jobs:
|
||||
|
||||
superlinter:
|
||||
name: SuperLinter
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
|
||||
3
.github/workflows/tool-test-sdks.yaml
vendored
3
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,7 +9,6 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
@ -18,7 +17,7 @@ concurrency:
|
||||
jobs:
|
||||
build:
|
||||
name: unit test for Node.js SDK
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
|
||||
defaults:
|
||||
run:
|
||||
|
||||
4
.github/workflows/translate-i18n-claude.yml
vendored
4
.github/workflows/translate-i18n-claude.yml
vendored
@ -35,7 +35,7 @@ concurrency:
|
||||
jobs:
|
||||
translate:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
timeout-minutes: 120
|
||||
|
||||
steps:
|
||||
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@38ec876110f9fbf8b950c79f534430740c3ac009 # v1.0.101
|
||||
uses: anthropics/claude-code-action@fefa07e9c665b7320f08c3b525980457f22f58aa # v1.0.111
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/trigger-i18n-sync.yml
vendored
2
.github/workflows/trigger-i18n-sync.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
|
||||
2
.github/workflows/vdb-tests.yml
vendored
2
.github/workflows/vdb-tests.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Smoke Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
|
||||
2
.github/workflows/web-e2e.yml
vendored
2
.github/workflows/web-e2e.yml
vendored
@ -13,7 +13,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Full-Stack E2E
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -16,7 +16,7 @@ concurrency:
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
VITEST_COVERAGE_SCOPE: app-components
|
||||
strategy:
|
||||
@ -54,7 +54,7 @@ jobs:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
@ -92,7 +92,7 @@ jobs:
|
||||
|
||||
dify-ui-test:
|
||||
name: dify-ui Tests
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: depot-ubuntu-24.04-4
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
defaults:
|
||||
|
||||
7
.gitignore
vendored
7
.gitignore
vendored
@ -219,6 +219,9 @@ node_modules
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
# generated API OpenAPI specs
|
||||
packages/contracts/openapi/
|
||||
|
||||
# mise
|
||||
mise.toml
|
||||
|
||||
@ -237,6 +240,10 @@ scripts/stress-test/reports/
|
||||
.playwright-mcp/
|
||||
.serena/
|
||||
|
||||
# vitest browser mode attachments (failure screenshots, traces, etc.)
|
||||
.vitest-attachments/
|
||||
**/__screenshots__/
|
||||
|
||||
# settings
|
||||
*.local.json
|
||||
*.local.md
|
||||
|
||||
@ -30,7 +30,7 @@ The codebase is split into:
|
||||
## Language Style
|
||||
|
||||
- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types.
|
||||
- **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check`, and avoid `any` types.
|
||||
|
||||
## General Practices
|
||||
|
||||
|
||||
22
README.md
22
README.md
@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
./dify-compose up -d
|
||||
```
|
||||
|
||||
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
|
||||
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
#### Seeking help
|
||||
@ -137,20 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
You can now customize the "Suggested Questions After Answer" feature to better fit your use case. For example, to generate longer, more technical questions:
|
||||
|
||||
```bash
|
||||
# In your .env file
|
||||
SUGGESTED_QUESTIONS_PROMPT='Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: ["question1","question2","question3","question4","question5"]'
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS=512
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE=0.3
|
||||
```
|
||||
|
||||
See the [Suggested Questions Configuration Guide](docs/suggested-questions-configuration.md) for detailed examples and usage instructions.
|
||||
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
### Metrics Monitoring with Grafana
|
||||
|
||||
@ -160,7 +148,7 @@ Import the dashboard to Grafana, using Dify's PostgreSQL database as data source
|
||||
|
||||
### Deployment with Kubernetes
|
||||
|
||||
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
If you'd like to configure a highly available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
|
||||
|
||||
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
|
||||
|
||||
@ -659,6 +659,11 @@ INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y
|
||||
MARKETPLACE_ENABLED=true
|
||||
MARKETPLACE_API_URL=https://marketplace.dify.ai
|
||||
|
||||
# Creators Platform configuration
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED=true
|
||||
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
|
||||
|
||||
# Endpoint configuration
|
||||
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
|
||||
|
||||
@ -709,22 +714,6 @@ SWAGGER_UI_PATH=/swagger-ui.html
|
||||
# Set to false to export dataset IDs as plain text for easier cross-environment import
|
||||
DSL_EXPORT_ENCRYPT_DATASET_ID=true
|
||||
|
||||
# Suggested Questions After Answer Configuration
|
||||
# These environment variables allow customization of the suggested questions feature
|
||||
#
|
||||
# Custom prompt for generating suggested questions (optional)
|
||||
# If not set, uses the default prompt that generates 3 questions under 20 characters each
|
||||
# Example: "Please help me predict the five most likely technical follow-up questions a developer would ask. Focus on implementation details, best practices, and architecture considerations. Keep each question between 40-60 characters. Output must be JSON array: [\"question1\",\"question2\",\"question3\",\"question4\",\"question5\"]"
|
||||
# SUGGESTED_QUESTIONS_PROMPT=
|
||||
|
||||
# Maximum number of tokens for suggested questions generation (default: 256)
|
||||
# Adjust this value for longer questions or more questions
|
||||
# SUGGESTED_QUESTIONS_MAX_TOKENS=256
|
||||
|
||||
# Temperature for suggested questions generation (default: 0.0)
|
||||
# Higher values (0.5-1.0) produce more creative questions, lower values (0.0-0.3) produce more focused questions
|
||||
# SUGGESTED_QUESTIONS_TEMPERATURE=0
|
||||
|
||||
# Tenant isolated task queue configuration
|
||||
TENANT_ISOLATED_TASK_CONCURRENCY=1
|
||||
|
||||
|
||||
@ -101,3 +101,11 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
uv run ruff format ./ # Format code
|
||||
uv run basedpyright . # Type checking
|
||||
```
|
||||
|
||||
## Generate TS stub
|
||||
|
||||
```
|
||||
uv run dev/generate_swagger_specs.py --output-dir openapi
|
||||
```
|
||||
|
||||
use https://jsontotable.org/openapi-to-typescript to convert to typescript
|
||||
|
||||
@ -113,8 +113,18 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
# Validates name encoding for non-Latin characters.
|
||||
name = name.strip().encode("utf-8").decode("utf-8") if name else None
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
# Generate a random password that satisfies the password policy.
|
||||
# The iteration limit guards against infinite loops caused by unexpected bugs in valid_password.
|
||||
for _ in range(100):
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
try:
|
||||
valid_password(new_password)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
click.echo(click.style("Failed to generate a valid password. Please try again.", fg="red"))
|
||||
return
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(
|
||||
|
||||
@ -11,7 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
@ -44,7 +44,7 @@ def setup_system_tool_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
@ -94,7 +94,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
|
||||
|
||||
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
|
||||
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
|
||||
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
|
||||
oauth_client_params = encrypt_system_params(client_params_dict)
|
||||
click.echo(click.style("Client params encrypted successfully.", fg="green"))
|
||||
except Exception as e:
|
||||
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
|
||||
|
||||
@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for Creators Platform integration
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable Creators Platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client ID for Creators Platform integration",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@ -1379,6 +1400,7 @@ class FeatureConfig(
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
CreatorsPlatformConfig,
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
|
||||
@ -74,6 +74,11 @@ class StorageConfig(BaseSettings):
|
||||
default="opendal",
|
||||
)
|
||||
|
||||
STORAGE_PATH_PREFIX: str = Field(
|
||||
description="Global path prefix prepended to all storage object keys.",
|
||||
default="",
|
||||
)
|
||||
|
||||
STORAGE_LOCAL_PATH: str = Field(
|
||||
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
|
||||
default="storage",
|
||||
|
||||
@ -41,7 +41,8 @@ def guess_file_info_from_response(response: httpx.Response):
|
||||
# Try to extract filename from URL
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
url_path = parsed_url.path
|
||||
filename = os.path.basename(url_path)
|
||||
# Decode percent-encoded characters in the path segment
|
||||
filename = urllib.parse.unquote(os.path.basename(url_path))
|
||||
|
||||
# If filename couldn't be extracted, use Content-Disposition header
|
||||
if not filename:
|
||||
|
||||
6
api/controllers/common/human_input.py
Normal file
6
api/controllers/common/human_input.py
Normal file
@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, JsonValue]
|
||||
action: str
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
@ -8,6 +9,7 @@ from flask_restx import Resource
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.datastructures import MultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -57,6 +59,7 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
@ -66,22 +69,19 @@ class AppListQuery(BaseModel):
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
def validate_tag_ids(cls, value: list[str] | None) -> list[str] | None:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
items = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(value, list):
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
else:
|
||||
raise TypeError("Unsupported tag_ids type.")
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("Unsupported tag_ids type.")
|
||||
|
||||
items = [str(item).strip() for item in value if item and str(item).strip()]
|
||||
if not items:
|
||||
return None
|
||||
|
||||
@ -91,6 +91,26 @@ class AppListQuery(BaseModel):
|
||||
raise ValueError("Invalid UUID format in tag_ids.") from exc
|
||||
|
||||
|
||||
def _normalize_app_list_query_args(query_args: MultiDict[str, str]) -> dict[str, str | list[str]]:
|
||||
normalized: dict[str, str | list[str]] = {}
|
||||
indexed_tag_ids: list[tuple[int, str]] = []
|
||||
|
||||
for key in query_args:
|
||||
match = _TAG_IDS_BRACKET_PATTERN.fullmatch(key)
|
||||
if match:
|
||||
indexed_tag_ids.extend((int(match.group(1)), value) for value in query_args.getlist(key))
|
||||
continue
|
||||
|
||||
value = query_args.get(key)
|
||||
if value is not None:
|
||||
normalized[key] = value
|
||||
|
||||
if indexed_tag_ids:
|
||||
normalized["tag_ids"] = [value for _, value in sorted(indexed_tag_ids)]
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
@ -455,7 +475,7 @@ class AppListApi(Resource):
|
||||
"""Get app list"""
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = AppListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = AppListQuery.model_validate(_normalize_app_list_query_args(request.args))
|
||||
args_dict = args.model_dump()
|
||||
|
||||
# get app list
|
||||
@ -692,6 +712,32 @@ class AppExportApi(Resource):
|
||||
return payload.model_dump(mode="json")
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
|
||||
class AppPublishToCreatorsPlatformApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=None)
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
"""Publish app to Creators Platform"""
|
||||
from configs import dify_config
|
||||
from core.helper.creators import get_redirect_url, upload_dsl
|
||||
|
||||
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
|
||||
return {"error": "Creators Platform features are not enabled"}, 403
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
dsl_content = AppDslService.export_dsl(app_model=app_model, include_secret=False)
|
||||
dsl_bytes = dsl_content.encode("utf-8")
|
||||
|
||||
claim_code = upload_dsl(dsl_bytes)
|
||||
redirect_url = get_redirect_url(str(current_user.id), claim_code)
|
||||
|
||||
return {"redirect_url": redirect_url}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/name")
|
||||
class AppNameApi(Resource):
|
||||
@console_ns.doc("check_app_name")
|
||||
|
||||
@ -60,7 +60,8 @@ _file_access_controller = DatabaseFileAccessController()
|
||||
LISTENING_RETRY_IN = 2000
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published"
|
||||
MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS = 50
|
||||
MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS = 1000
|
||||
WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE = 50
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
# Register in dependency order: base models first, then dependent models
|
||||
@ -158,8 +159,13 @@ class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
app_ids: str = Field(..., description="Comma-separated app IDs")
|
||||
class WorkflowOnlineUsersPayload(BaseModel):
|
||||
app_ids: list[str] = Field(default_factory=list, description="App IDs")
|
||||
|
||||
@field_validator("app_ids")
|
||||
@classmethod
|
||||
def normalize_app_ids(cls, app_ids: list[str]) -> list[str]:
|
||||
return list(dict.fromkeys(app_id.strip() for app_id in app_ids if app_id.strip()))
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
@ -186,7 +192,7 @@ reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(WorkflowOnlineUsersPayload)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -1384,19 +1390,19 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersPayload.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
def post(self):
|
||||
args = WorkflowOnlineUsersPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
app_ids = list(dict.fromkeys(app_id.strip() for app_id in args.app_ids.split(",") if app_id.strip()))
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_QUERY_IDS} app_ids are allowed per request.")
|
||||
app_ids = args.app_ids
|
||||
if len(app_ids) > MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS:
|
||||
raise BadRequest(f"Maximum {MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS} app_ids are allowed per request.")
|
||||
|
||||
if not app_ids:
|
||||
return {"data": []}
|
||||
@ -1404,13 +1410,24 @@ class WorkflowOnlineUsersApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
accessible_app_ids = workflow_service.get_accessible_app_ids(app_ids, current_tenant_id)
|
||||
ordered_accessible_app_ids = [app_id for app_id in app_ids if app_id in accessible_app_ids]
|
||||
|
||||
users_json_by_app_id: dict[str, Any] = {}
|
||||
for start_index in range(0, len(ordered_accessible_app_ids), WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE):
|
||||
app_id_batch = ordered_accessible_app_ids[
|
||||
start_index : start_index + WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE
|
||||
]
|
||||
pipe = redis_client.pipeline(transaction=False)
|
||||
for app_id in app_id_batch:
|
||||
pipe.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
|
||||
users_json_batch = pipe.execute()
|
||||
for app_id, users_json in zip(app_id_batch, users_json_batch):
|
||||
users_json_by_app_id[app_id] = users_json
|
||||
|
||||
results = []
|
||||
for app_id in app_ids:
|
||||
if app_id not in accessible_app_ids:
|
||||
continue
|
||||
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{app_id}")
|
||||
for app_id in ordered_accessible_app_ids:
|
||||
users_json = users_json_by_app_id.get(app_id, {})
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
|
||||
@ -75,14 +75,15 @@ console_ns.schema_model(
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
match value:
|
||||
case FileSegment():
|
||||
return value.value.model_dump()
|
||||
case ArrayFileSegment():
|
||||
return [i.model_dump() for i in value.value]
|
||||
case SegmentGroup():
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
case _:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable):
|
||||
|
||||
@ -38,6 +38,48 @@ class HitTestingPayload(BaseModel):
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_query(query: Any) -> str:
|
||||
"""Return the user-visible query string from legacy and current response shapes."""
|
||||
if isinstance(query, str):
|
||||
return query
|
||||
|
||||
if isinstance(query, dict):
|
||||
content = query.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
raise ValueError("Invalid hit testing query response")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_hit_testing_records(records: Any) -> list[dict[str, Any]]:
|
||||
"""Coerce nullable collection fields into lists before response validation."""
|
||||
if not isinstance(records, list):
|
||||
return []
|
||||
|
||||
normalized_records: list[dict[str, Any]] = []
|
||||
for record in records:
|
||||
if not isinstance(record, dict):
|
||||
continue
|
||||
|
||||
normalized_record = dict(record)
|
||||
segment = normalized_record.get("segment")
|
||||
if isinstance(segment, dict):
|
||||
normalized_segment = dict(segment)
|
||||
if normalized_segment.get("keywords") is None:
|
||||
normalized_segment["keywords"] = []
|
||||
normalized_record["segment"] = normalized_segment
|
||||
|
||||
if normalized_record.get("child_chunks") is None:
|
||||
normalized_record["child_chunks"] = []
|
||||
|
||||
if normalized_record.get("files") is None:
|
||||
normalized_record["files"] = []
|
||||
|
||||
normalized_records.append(normalized_record)
|
||||
|
||||
return normalized_records
|
||||
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
assert isinstance(current_user, Account)
|
||||
@ -75,7 +117,12 @@ class DatasetsHitTestingBase:
|
||||
attachment_ids=args.get("attachment_ids"),
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
return {
|
||||
"query": DatasetsHitTestingBase._normalize_hit_testing_query(response.get("query")),
|
||||
"records": DatasetsHitTestingBase._normalize_hit_testing_records(
|
||||
marshal(response.get("records", []), hit_testing_record_fields)
|
||||
),
|
||||
}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@ -8,10 +8,10 @@ from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
@ -34,11 +34,6 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
|
||||
@ -32,7 +32,7 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1)
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
@ -152,41 +152,68 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
class TagBindingCreateApi(Resource):
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(
|
||||
tag_ids=payload.tag_ids,
|
||||
target_id=payload.target_id,
|
||||
type=payload.type,
|
||||
)
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings")
|
||||
class TagBindingCollectionApi(Resource):
|
||||
"""Canonical collection resource for tag binding creation."""
|
||||
|
||||
@console_ns.doc("create_tag_binding")
|
||||
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return _create_tag_bindings()
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
class TagBindingDeleteApi(Resource):
|
||||
class TagBindingRemoveApi(Resource):
|
||||
"""Batch resource for tag binding deletion."""
|
||||
|
||||
@console_ns.doc("remove_tag_bindings")
|
||||
@console_ns.doc(description="Remove one or more tag bindings from a target.")
|
||||
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return _remove_tag_bindings()
|
||||
|
||||
@ -8,6 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@ -45,6 +46,8 @@ from libs.helper import EmailStr, extract_remote_ip, timezone
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
@ -322,9 +325,24 @@ class AccountAvatarApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
avatar = args.avatar
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
if avatar.startswith(("http://", "https://")):
|
||||
return {"avatar_url": avatar}
|
||||
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == avatar).limit(1))
|
||||
if upload_file is None:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.tenant_id != current_tenant_id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
if upload_file.created_by_role != CreatorUserRole.ACCOUNT or upload_file.created_by != current_user.id:
|
||||
raise NotFound("Avatar file not found")
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
|
||||
@ -1,3 +1,11 @@
|
||||
"""Console workspace endpoint controllers.
|
||||
|
||||
This module exposes workspace-scoped plugin endpoint management APIs. The
|
||||
canonical write routes follow resource-oriented paths, while the historical
|
||||
verb-based aliases stay available as deprecated resources so OpenAPI metadata
|
||||
marks only the legacy paths as deprecated.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
@ -25,7 +33,12 @@ class EndpointIdPayload(BaseModel):
|
||||
endpoint_id: str
|
||||
|
||||
|
||||
class EndpointUpdatePayload(EndpointIdPayload):
|
||||
class EndpointUpdatePayload(BaseModel):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
|
||||
class LegacyEndpointUpdatePayload(EndpointIdPayload):
|
||||
settings: dict[str, Any]
|
||||
name: str = Field(min_length=1)
|
||||
|
||||
@ -76,6 +89,7 @@ register_schema_models(
|
||||
EndpointCreatePayload,
|
||||
EndpointIdPayload,
|
||||
EndpointUpdatePayload,
|
||||
LegacyEndpointUpdatePayload,
|
||||
EndpointListQuery,
|
||||
EndpointListForPluginQuery,
|
||||
EndpointCreateResponse,
|
||||
@ -88,8 +102,60 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class EndpointCreateApi(Resource):
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints")
|
||||
class EndpointCollectionApi(Resource):
|
||||
"""Canonical collection resource for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint")
|
||||
@console_ns.doc(description="Create a new plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@ -104,22 +170,33 @@ class EndpointCreateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
return _create_endpoint()
|
||||
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
class DeprecatedEndpointCreateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint creation."""
|
||||
|
||||
@console_ns.doc("create_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a plugin endpoint. Use POST /workspaces/current/endpoints instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointCreatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint created successfully",
|
||||
console_ns.models[EndpointCreateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -190,10 +267,56 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class EndpointDeleteApi(Resource):
|
||||
@console_ns.route("/workspaces/current/endpoints/<string:id>")
|
||||
class EndpointItemApi(Resource):
|
||||
"""Canonical item resource for endpoint updates and deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint")
|
||||
@console_ns.doc(description="Delete a plugin endpoint")
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint deleted successfully",
|
||||
console_ns.models[EndpointDeleteResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
@console_ns.doc(params={"id": {"description": "Endpoint ID", "type": "string", "required": True}})
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
console_ns.models[EndpointUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.response(403, "Admin privileges required")
|
||||
@setup_required
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
class DeprecatedEndpointDeleteApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint deletion."""
|
||||
|
||||
@console_ns.doc("delete_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for deleting a plugin endpoint. "
|
||||
"Use DELETE /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[EndpointIdPayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
@ -206,22 +329,23 @@ class EndpointDeleteApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
class EndpointUpdateApi(Resource):
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@console_ns.expect(console_ns.models[EndpointUpdatePayload.__name__])
|
||||
class DeprecatedEndpointUpdateApi(Resource):
|
||||
"""Deprecated verb-based alias for endpoint updates."""
|
||||
|
||||
@console_ns.doc("update_endpoint_deprecated")
|
||||
@console_ns.doc(deprecated=True)
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating a plugin endpoint. "
|
||||
"Use PATCH /workspaces/current/endpoints/{id} instead."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[LegacyEndpointUpdatePayload.__name__])
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Endpoint updated successfully",
|
||||
@ -233,19 +357,8 @@ class EndpointUpdateApi(Resource):
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
endpoint_id=args.endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
)
|
||||
}
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
|
||||
@ -8,13 +8,12 @@ paused human input forms in workflow/chatflow runs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@ -26,11 +25,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, HumanInputFormSubmitPayload)
|
||||
|
||||
|
||||
@ -127,7 +121,7 @@ class WorkflowHumanInputFormApi(Resource):
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%s", form.id)
|
||||
raise InternalServerError("Form recipient type is invalid")
|
||||
raise BadRequest("Form recipient type is invalid")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
|
||||
@ -18,6 +18,8 @@ from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from core.workflow.human_input_policy import HumanInputSurface
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
@ -35,8 +37,14 @@ class WorkflowEventsApi(Resource):
|
||||
params={
|
||||
"task_id": "Workflow run ID",
|
||||
"user": "End user identifier (query param)",
|
||||
"include_state_snapshot": "Whether to replay from persisted state snapshot",
|
||||
"continue_on_pause": "Whether to keep the stream open across workflow_paused events",
|
||||
"include_state_snapshot": (
|
||||
"Whether to replay from persisted state snapshot, "
|
||||
'specify `"true"` to include a status snapshot of executed nodes'
|
||||
),
|
||||
"continue_on_pause": (
|
||||
"Whether to keep the stream open across workflow_paused events,"
|
||||
'specify `"true"` to keep the stream open for `workflow_paused` events.'
|
||||
),
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(
|
||||
@ -99,7 +107,7 @@ class WorkflowEventsApi(Resource):
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
continue_on_pause = request.args.get("continue_on_pause", "false").lower() == "true"
|
||||
terminal_events = ["workflow_finished"] if continue_on_pause else None
|
||||
terminal_events: list[StreamEvent] | None = [] if continue_on_pause else None
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
@ -110,6 +118,7 @@ class WorkflowEventsApi(Resource):
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
human_input_surface=HumanInputSurface.SERVICE_API,
|
||||
close_on_pause=not continue_on_pause,
|
||||
)
|
||||
)
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel):
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str
|
||||
"""Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally."""
|
||||
|
||||
tag_ids: list[str] = Field(default_factory=list)
|
||||
tag_id: str | None = None
|
||||
target_id: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def normalize_legacy_tag_id(cls, data: object) -> object:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("tag_ids") and data.get("tag_id"):
|
||||
return {**data, "tag_ids": [data["tag_id"]]}
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_tag_ids(self) -> "TagUnbindingPayload":
|
||||
if not self.tag_ids:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return self
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc("unbind_dataset_tags")
|
||||
@service_api_ns.doc(description="Unbind tags from a dataset")
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
204: "Tag unbound successfully",
|
||||
204: "Tags unbound successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,4 +1,12 @@
|
||||
"""Service API endpoints for dataset document management.
|
||||
|
||||
The canonical Service API paths use hyphenated route segments. Legacy underscore
|
||||
aliases remain registered for backward compatibility, but they must stay marked
|
||||
deprecated in generated API docs so clients migrate toward the canonical paths.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from contextlib import ExitStack
|
||||
from typing import Self
|
||||
from uuid import UUID
|
||||
@ -117,12 +125,137 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/document/create_by_text",
|
||||
"/datasets/<uuid:dataset_id>/document/create-by-text",
|
||||
)
|
||||
def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Create a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id_str, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id_str,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id_str
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from text for both canonical and legacy routes."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create-by-text")
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
"""Resource for the canonical text document creation route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@ -138,81 +271,43 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text."""
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/document/create_by_text")
|
||||
class DeprecatedDocumentAddByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document creation."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for creating a new document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/document/create-by-text instead."
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document created successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Create document by text through the deprecated underscore alias."""
|
||||
return _create_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
|
||||
)
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text")
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
"""Resource for the canonical text document update route."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@ -229,62 +324,35 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1)
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
|
||||
class DeprecatedDocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Deprecated resource alias for text document updates."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by providing text content. "
|
||||
"Use /datasets/{dataset_id}/documents/{document_id}/update-by-text instead."
|
||||
)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if args.get("text"):
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=current_user,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch}
|
||||
return documents_and_batch_fields, 200
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text through the deprecated underscore alias."""
|
||||
return _update_document_by_text(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -400,15 +468,98 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID) -> tuple[Mapping[str, object], int]:
|
||||
"""Update a document from an uploaded file for canonical and deprecated routes."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
tenant_id_str = str(tenant_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id_str, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args: dict[str, object] = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
|
||||
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
|
||||
)
|
||||
class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
class DeprecatedDocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Deprecated resource aliases for file document updates."""
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc("update_document_by_file_deprecated")
|
||||
@service_api_ns.doc(deprecated=True)
|
||||
@service_api_ns.doc(
|
||||
description=(
|
||||
"Deprecated legacy alias for updating an existing document by uploading a file. "
|
||||
"Use PATCH /datasets/{dataset_id}/documents/{document_id} instead."
|
||||
)
|
||||
)
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
@ -419,82 +570,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
args = {}
|
||||
if "data" in request.form:
|
||||
args = json.loads(request.form["data"])
|
||||
if "doc_form" not in args:
|
||||
args["doc_form"] = dataset.chunk_structure or "text_model"
|
||||
if "doc_language" not in args:
|
||||
args["doc_language"] = "English"
|
||||
|
||||
# get dataset info
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
if "file" in request.files:
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
try:
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
user=current_user,
|
||||
source="datasets",
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
}
|
||||
args["data_source"] = data_source
|
||||
# validate args
|
||||
args["original_document_id"] = str(document_id)
|
||||
|
||||
knowledge_config = KnowledgeConfig.model_validate(args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
try:
|
||||
documents, _ = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
document = documents[0]
|
||||
documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": document.batch}
|
||||
return documents_and_batch_fields, 200
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file through the deprecated file-update aliases."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents")
|
||||
@ -808,6 +886,22 @@ class DocumentApi(DatasetApiResource):
|
||||
|
||||
return response
|
||||
|
||||
@service_api_ns.doc("update_document_by_file")
|
||||
@service_api_ns.doc(description="Update an existing document by uploading a file")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@service_api_ns.doc(
|
||||
responses={
|
||||
200: "Document updated successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "Document not found",
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by file on the canonical document resource."""
|
||||
return _update_document_by_file(tenant_id=tenant_id, dataset_id=dataset_id, document_id=document_id)
|
||||
|
||||
@service_api_ns.doc("delete_document")
|
||||
@service_api_ns.doc(description="Delete a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
|
||||
@ -23,7 +23,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@ -69,12 +69,12 @@ class AudioApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert audio to text"""
|
||||
file = request.files["file"]
|
||||
|
||||
try:
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user)
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user.external_user_id)
|
||||
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
@ -117,7 +117,7 @@ class TextApi(WebApiResource):
|
||||
500: "Internal Server Error",
|
||||
}
|
||||
)
|
||||
def post(self, app_model: App, end_user):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
"""Convert text to audio"""
|
||||
try:
|
||||
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
|
||||
|
||||
@ -9,11 +9,11 @@ from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
@ -26,11 +26,6 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
inputs: dict
|
||||
action: str
|
||||
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH = 1000
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
@ -20,7 +22,11 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
Validate and set defaults for suggested questions feature.
|
||||
|
||||
Optional fields:
|
||||
- prompt: custom instruction prompt.
|
||||
- model: provider/model configuration for suggested question generation.
|
||||
|
||||
:param config: app model config args
|
||||
"""
|
||||
@ -39,4 +45,27 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
|
||||
raise ValueError("enabled in suggested_questions_after_answer must be of boolean type")
|
||||
|
||||
prompt = config["suggested_questions_after_answer"].get("prompt")
|
||||
if prompt is not None and not isinstance(prompt, str):
|
||||
raise ValueError("prompt in suggested_questions_after_answer must be of string type")
|
||||
if isinstance(prompt, str) and len(prompt) > CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH:
|
||||
raise ValueError(
|
||||
f"prompt in suggested_questions_after_answer must be less than or equal to "
|
||||
f"{CUSTOM_FOLLOW_UP_PROMPT_MAX_LENGTH} characters"
|
||||
)
|
||||
|
||||
if "model" in config["suggested_questions_after_answer"]:
|
||||
model_config = config["suggested_questions_after_answer"]["model"]
|
||||
if not isinstance(model_config, dict):
|
||||
raise ValueError("model in suggested_questions_after_answer must be of object type")
|
||||
|
||||
if "provider" not in model_config or not isinstance(model_config["provider"], str):
|
||||
raise ValueError("provider in suggested_questions_after_answer.model must be of string type")
|
||||
|
||||
if "name" not in model_config or not isinstance(model_config["name"], str):
|
||||
raise ValueError("name in suggested_questions_after_answer.model must be of string type")
|
||||
|
||||
if "completion_params" in model_config and not isinstance(model_config["completion_params"], dict):
|
||||
raise ValueError("completion_params in suggested_questions_after_answer.model must be of object type")
|
||||
|
||||
return config, ["suggested_questions_after_answer"]
|
||||
|
||||
@ -35,8 +35,8 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
@ -660,7 +660,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]
|
||||
ChatbotAppBlockingResponse
|
||||
| AdvancedChatPausedBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse, None, None]
|
||||
):
|
||||
"""
|
||||
Handle response.
|
||||
|
||||
@ -3,34 +3,35 @@ from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
AppStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
StreamEvent,
|
||||
)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse]
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
if isinstance(blocking_response, ChatbotAppPausedBlockingResponse):
|
||||
if isinstance(blocking_response, AdvancedChatPausedBlockingResponse):
|
||||
paused_data = blocking_response.data.model_dump(mode="json")
|
||||
return {
|
||||
"event": "workflow_paused",
|
||||
"event": StreamEvent.WORKFLOW_PAUSED.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -44,7 +45,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
}
|
||||
|
||||
response = {
|
||||
"event": "message",
|
||||
"event": StreamEvent.MESSAGE.value,
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
"message_id": blocking_response.data.message_id,
|
||||
@ -59,7 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | ChatbotAppPausedBlockingResponse
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
|
||||
@ -53,10 +53,11 @@ from core.app.entities.queue_entities import (
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
@ -74,7 +75,7 @@ from core.repositories.human_input_repository import HumanInputFormRepositoryImp
|
||||
from core.workflow.file_reference import resolve_file_record_id
|
||||
from core.workflow.system_variables import build_system_variables
|
||||
from extensions.ext_database import db
|
||||
from graphon.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from graphon.entities.pause_reason import HumanInputRequired
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -217,7 +218,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppPausedBlockingResponse,
|
||||
AdvancedChatPausedBlockingResponse,
|
||||
Generator[ChatbotAppStreamResponse, None, None],
|
||||
]:
|
||||
"""
|
||||
@ -237,7 +238,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, ChatbotAppPausedBlockingResponse]:
|
||||
) -> Union[ChatbotAppBlockingResponse, AdvancedChatPausedBlockingResponse]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@ -249,9 +250,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
elif isinstance(stream_response, HumanInputRequiredResponse):
|
||||
human_input_responses.append(stream_response)
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
return ChatbotAppPausedBlockingResponse(
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=stream_response.task_id,
|
||||
data=ChatbotAppPausedBlockingResponse.Data(
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
@ -295,18 +296,17 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _build_paused_blocking_response_from_human_input(
|
||||
self, human_input_responses: list[HumanInputRequiredResponse]
|
||||
) -> ChatbotAppPausedBlockingResponse:
|
||||
) -> AdvancedChatPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
reasons = []
|
||||
for response in human_input_responses:
|
||||
reason = response.data.model_dump(mode="json")
|
||||
reason["type"] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
reasons.append(reason)
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
|
||||
return ChatbotAppPausedBlockingResponse(
|
||||
return AdvancedChatPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppPausedBlockingResponse.Data(
|
||||
data=AdvancedChatPausedBlockingResponse.Data(
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -68,7 +70,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -99,7 +101,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
@ -108,13 +110,13 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, JsonValue]:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||
error_responses: dict[type[Exception], dict[str, JsonValue]] = {
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
@ -128,7 +130,7 @@ class AppGenerateResponseConverter[TBlockingResponse: AppBlockingResponse](ABC):
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data: dict[str, Any] | None = None
|
||||
data: dict[str, JsonValue] | None = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -68,7 +70,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
@ -99,7 +101,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"conversation_id": chunk.conversation_id,
|
||||
"message_id": chunk.message_id,
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
def pause_reason_to_public_dict(reason: PauseReason | Mapping[str, Any]) -> dict[str, Any]:
|
||||
if isinstance(reason, Mapping):
|
||||
data = dict(reason)
|
||||
else:
|
||||
data = dict(reason.model_dump(mode="json"))
|
||||
|
||||
discriminator = data.pop("TYPE", None)
|
||||
if discriminator is not None:
|
||||
data["type"] = discriminator
|
||||
|
||||
return data
|
||||
@ -9,9 +9,7 @@ from typing import Any, NewType, TypedDict, Union
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.common.pause_reason_serializer import pause_reason_to_public_dict
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.workflow.human_input_policy import enrich_human_input_pause_reasons
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
@ -54,6 +52,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.workflow.human_input_forms import load_form_tokens_by_form_id
|
||||
from core.workflow.human_input_policy import HumanInputSurface, enrich_human_input_pause_reasons
|
||||
from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
@ -319,7 +318,7 @@ class WorkflowResponseConverter:
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [pause_reason_to_public_dict(reason) for reason in event.reasons]
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
display_in_ui_by_form_id: dict[str, bool] = {}
|
||||
@ -338,7 +337,15 @@ class WorkflowResponseConverter:
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
definition_payload = {}
|
||||
display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui"))
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session)
|
||||
form_token_by_form_id = load_form_tokens_by_form_id(
|
||||
human_input_form_ids,
|
||||
session=session,
|
||||
surface=(
|
||||
HumanInputSurface.SERVICE_API
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Reconnect paths must preserve the same pause-reason contract as live streams;
|
||||
# otherwise clients see schema drift after resume.
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
@ -20,7 +22,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
response = {
|
||||
response: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"task_id": blocking_response.task_id,
|
||||
"id": blocking_response.data.id,
|
||||
@ -67,7 +69,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
@ -97,7 +99,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
response_chunk: dict[str, JsonValue] = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"message_id": chunk.message_id,
|
||||
"created_at": chunk.created_at,
|
||||
|
||||
@ -59,7 +59,7 @@ def stream_topic_events(
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
if terminal_events is None:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
|
||||
@ -63,7 +63,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict[str, object], data))
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
@ -92,9 +92,9 @@ class WorkflowAppGenerateResponseConverter(
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(cast(dict[str, object], data))
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(cast(dict[str, object], sub_stream_response.to_ignore_detail_dict()))
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||
yield response_chunk
|
||||
|
||||
@ -42,6 +42,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
HumanInputRequiredPauseReasonPayload,
|
||||
HumanInputRequiredResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
@ -199,14 +200,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
) -> WorkflowAppPausedBlockingResponse:
|
||||
runtime_state = self._resolve_graph_runtime_state()
|
||||
paused_nodes = list(dict.fromkeys(response.data.node_id for response in human_input_responses))
|
||||
# Graph runtime `start_at` is a perf-counter value, not an epoch timestamp, so
|
||||
# fallback API payloads need a wall-clock source for `created_at`.
|
||||
created_at = int(time.time())
|
||||
reasons = []
|
||||
for response in human_input_responses:
|
||||
reason = response.data.model_dump(mode="json")
|
||||
reason["type"] = "human_input_required"
|
||||
reasons.append(reason)
|
||||
created_at = int(runtime_state.start_at)
|
||||
reasons = [
|
||||
HumanInputRequiredPauseReasonPayload.from_response_data(response.data).model_dump(mode="json")
|
||||
for response in human_input_responses
|
||||
]
|
||||
|
||||
return WorkflowAppPausedBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, JsonValue
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from graphon.entities import WorkflowStartReason
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
@ -295,6 +296,40 @@ class HumanInputRequiredResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredPauseReasonPayload(BaseModel):
|
||||
"""
|
||||
Public pause-reason payload used by blocking responses when only
|
||||
``human_input_required`` events are available.
|
||||
"""
|
||||
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
@classmethod
|
||||
def from_response_data(cls, data: HumanInputRequiredResponse.Data) -> "HumanInputRequiredPauseReasonPayload":
|
||||
return cls(
|
||||
form_id=data.form_id,
|
||||
node_id=data.node_id,
|
||||
node_title=data.node_title,
|
||||
form_content=data.form_content,
|
||||
inputs=data.inputs,
|
||||
actions=data.actions,
|
||||
display_in_ui=data.display_in_ui,
|
||||
form_token=data.form_token,
|
||||
resolved_default_values=data.resolved_default_values,
|
||||
expiration_time=data.expiration_time,
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
@ -355,7 +390,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -412,7 +447,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
def to_ignore_detail_dict(self):
|
||||
def to_ignore_detail_dict(self) -> dict[str, JsonValue]:
|
||||
return {
|
||||
"event": self.event.value,
|
||||
"task_id": self.task_id,
|
||||
@ -774,7 +809,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class ChatbotAppPausedBlockingResponse(AppBlockingResponse):
|
||||
class AdvancedChatPausedBlockingResponse(AppBlockingResponse):
|
||||
"""
|
||||
ChatbotAppPausedBlockingResponse entity
|
||||
"""
|
||||
@ -793,7 +828,7 @@ class ChatbotAppPausedBlockingResponse(AppBlockingResponse):
|
||||
metadata: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list[Mapping[str, Any]])
|
||||
status: WorkflowExecutionStatus
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Generator # Changed from Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass
|
||||
@ -32,7 +32,7 @@ def get_current_file_access_scope() -> FileAccessScope | None:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]:
|
||||
def bind_file_access_scope(scope: FileAccessScope) -> Generator[None, None, None]: # Changed from Iterator[None]
|
||||
token = _current_file_access_scope.set(scope)
|
||||
try:
|
||||
yield
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
||||
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
"""Resolves and returns LLM credentials for a given provider and model.
|
||||
|
||||
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||
Because of that cache, a single instance can return stale credentials after
|
||||
the tenant or provider configuration changes (e.g. API key rotation).
|
||||
|
||||
Do **not** keep one instance for the lifetime of a process or across
|
||||
unrelated invocations. Create a new provider per request, workflow run, or
|
||||
other bounded scope where up-to-date credentials matter.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
self.provider_manager = provider_manager
|
||||
self.credentials_cache = {}
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
if (provider_name, model_name) in self.credentials_cache:
|
||||
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||
return credentials
|
||||
|
||||
|
||||
@ -65,7 +84,8 @@ class DifyModelFactory:
|
||||
provider_manager=create_plugin_provider_manager(
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
),
|
||||
enable_credentials_cache=True,
|
||||
)
|
||||
self.model_manager = model_manager
|
||||
|
||||
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
|
||||
tenant_id=run_context.tenant_id,
|
||||
user_id=run_context.user_id,
|
||||
)
|
||||
model_manager = ModelManager(provider_manager=provider_manager)
|
||||
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
|
||||
|
||||
return (
|
||||
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
||||
|
||||
41
api/core/helper/creators.py
Normal file
41
api/core/helper/creators.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
from typing import Any, NotRequired, Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from sqlalchemy import select
|
||||
@ -18,8 +18,6 @@ from core.llm_generator.prompts import (
|
||||
LLM_MODIFY_CODE_SYSTEM,
|
||||
LLM_MODIFY_PROMPT_SYSTEM,
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
)
|
||||
@ -41,6 +39,36 @@ from models.workflow import Workflow
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuggestedQuestionsModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
completion_params: NotRequired[dict[str, object]]
|
||||
|
||||
|
||||
def _normalize_completion_params(completion_params: dict[str, object]) -> tuple[dict[str, object], list[str]]:
|
||||
"""
|
||||
Normalize raw completion params into invocation parameters and stop sequences.
|
||||
|
||||
This mirrors the app-model access path by separating ``stop`` from provider
|
||||
parameters before invocation, then drops non-positive token limits because
|
||||
some plugin-backed models reject ``0`` after mapping ``max_tokens`` to their
|
||||
provider-specific output-token field.
|
||||
"""
|
||||
normalized_parameters = dict(completion_params)
|
||||
stop_value = normalized_parameters.pop("stop", [])
|
||||
if isinstance(stop_value, list) and all(isinstance(item, str) for item in stop_value):
|
||||
stop = stop_value
|
||||
else:
|
||||
stop = []
|
||||
|
||||
for token_limit_key in ("max_tokens", "max_output_tokens"):
|
||||
token_limit = normalized_parameters.get(token_limit_key)
|
||||
if isinstance(token_limit, int | float) and token_limit <= 0:
|
||||
normalized_parameters.pop(token_limit_key, None)
|
||||
|
||||
return normalized_parameters, stop
|
||||
|
||||
|
||||
class WorkflowServiceInterface(Protocol):
|
||||
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
|
||||
pass
|
||||
@ -123,8 +151,15 @@ class LLMGenerator:
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str) -> Sequence[str]:
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
|
||||
def generate_suggested_questions_after_answer(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
histories: str,
|
||||
*,
|
||||
instruction_prompt: str | None = None,
|
||||
model_config: object | None = None,
|
||||
) -> Sequence[str]:
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser(instruction_prompt=instruction_prompt)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")
|
||||
@ -133,10 +168,36 @@ class LLMGenerator:
|
||||
|
||||
try:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
configured_model = cast(dict[str, object], model_config) if isinstance(model_config, dict) else {}
|
||||
provider = configured_model.get("provider")
|
||||
model_name = configured_model.get("name")
|
||||
use_configured_model = False
|
||||
|
||||
if isinstance(provider, str) and provider and isinstance(model_name, str) and model_name:
|
||||
try:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
)
|
||||
use_configured_model = True
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to use configured suggested-questions model %s/%s, fallback to default model",
|
||||
provider,
|
||||
model_name,
|
||||
exc_info=True,
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
else:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
except InvokeAuthorizationError:
|
||||
return []
|
||||
|
||||
@ -145,19 +206,29 @@ class LLMGenerator:
|
||||
questions: Sequence[str] = []
|
||||
|
||||
try:
|
||||
configured_completion_params = configured_model.get("completion_params")
|
||||
if use_configured_model and isinstance(configured_completion_params, dict):
|
||||
model_parameters, stop = _normalize_completion_params(configured_completion_params)
|
||||
elif use_configured_model:
|
||||
model_parameters = {}
|
||||
stop = []
|
||||
else:
|
||||
# Default-model generation keeps the built-in suggested-questions tuning.
|
||||
model_parameters = {
|
||||
"max_tokens": 2560,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
stop = []
|
||||
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages),
|
||||
model_parameters={
|
||||
"max_tokens": SUGGESTED_QUESTIONS_MAX_TOKENS,
|
||||
"temperature": SUGGESTED_QUESTIONS_TEMPERATURE,
|
||||
},
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
text_content = response.message.get_text_content()
|
||||
questions = output_parser.parse(text_content) if text_content else []
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception:
|
||||
logger.exception("Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
@ -3,17 +3,28 @@ import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
from core.llm_generator.prompts import DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerOutputParser:
|
||||
def __init__(self, instruction_prompt: str | None = None) -> None:
|
||||
self._instruction_prompt = self._build_instruction_prompt(instruction_prompt)
|
||||
|
||||
@staticmethod
|
||||
def _build_instruction_prompt(instruction_prompt: str | None) -> str:
|
||||
if not instruction_prompt or not instruction_prompt.strip():
|
||||
return DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
|
||||
return f'{instruction_prompt}\nYou must output a JSON array like ["question1", "question2", "question3"].'
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
|
||||
return self._instruction_prompt
|
||||
|
||||
def parse(self, text: str) -> Sequence[str]:
|
||||
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
|
||||
stripped_text = text.strip()
|
||||
action_match = re.search(r"\[.*?\]", stripped_text, re.DOTALL)
|
||||
questions: list[str] = []
|
||||
if action_match is not None:
|
||||
try:
|
||||
@ -23,4 +34,6 @@ class SuggestedQuestionsAfterAnswerOutputParser:
|
||||
else:
|
||||
if isinstance(json_obj, list):
|
||||
questions = [question for question in json_obj if isinstance(question, str)]
|
||||
elif stripped_text:
|
||||
logger.warning("Failed to find suggested questions payload array in text: %r", stripped_text[:200])
|
||||
return questions
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
|
||||
import os
|
||||
|
||||
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
|
||||
|
||||
@ -96,8 +95,8 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
|
||||
)
|
||||
|
||||
|
||||
# Default prompt for suggested questions (can be overridden by environment variable)
|
||||
_DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
|
||||
# Default prompt and model parameters for suggested questions.
|
||||
DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
|
||||
"Please help me predict the three most likely questions that human would ask, "
|
||||
"and keep each question under 20 characters.\n"
|
||||
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
|
||||
@ -105,15 +104,6 @@ _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT = (
|
||||
'["question1","question2","question3"]\n'
|
||||
)
|
||||
|
||||
# Environment variable override for suggested questions prompt
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = os.getenv(
|
||||
"SUGGESTED_QUESTIONS_PROMPT", _DEFAULT_SUGGESTED_QUESTIONS_AFTER_ANSWER_PROMPT
|
||||
)
|
||||
|
||||
# Configurable LLM parameters for suggested questions (can be overridden by environment variables)
|
||||
SUGGESTED_QUESTIONS_MAX_TOKENS = int(os.getenv("SUGGESTED_QUESTIONS_MAX_TOKENS", "256"))
|
||||
SUGGESTED_QUESTIONS_TEMPERATURE = float(os.getenv("SUGGESTED_QUESTIONS_TEMPERATURE", "0"))
|
||||
|
||||
GENERATOR_QA_PROMPT = (
|
||||
"<Task> The user will send a long text. Generate a Question and Answer pairs only using the knowledge"
|
||||
" in the long text. Please think step by step."
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||
|
||||
from configs import dify_config
|
||||
@ -36,11 +37,13 @@ class ModelInstance:
|
||||
Model instance class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
if credentials is None:
|
||||
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.credentials = credentials
|
||||
# Runtime LLM invocation fields.
|
||||
self.parameters: Mapping[str, Any] = {}
|
||||
self.stop: Sequence[str] = ()
|
||||
@ -434,8 +437,30 @@ class ModelInstance:
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, provider_manager: ProviderManager):
|
||||
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
|
||||
|
||||
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
|
||||
``(tenant_id, provider, model_type, model)`` are stored in
|
||||
``_credentials_cache`` and reused. That can return **stale** credentials after
|
||||
API keys or provider settings change, so a manager constructed with
|
||||
``enable_credentials_cache=True`` should not be kept for the lifetime of a
|
||||
process or shared across unrelated work. Prefer a new manager per request,
|
||||
workflow run, or similar bounded scope.
|
||||
|
||||
The default is ``enable_credentials_cache=False``; in that mode the internal
|
||||
credential cache is not populated, and each ``get_model_instance`` call
|
||||
loads credentials from the current provider configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
*,
|
||||
enable_credentials_cache: bool = False,
|
||||
) -> None:
|
||||
self._provider_manager = provider_manager
|
||||
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
|
||||
self._enable_credentials_cache = enable_credentials_cache
|
||||
|
||||
@classmethod
|
||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||
@ -463,8 +488,19 @@ class ModelManager:
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
model_instance = ModelInstance(provider_model_bundle, model)
|
||||
return model_instance
|
||||
cred_cache_key = (tenant_id, provider, model_type.value, model)
|
||||
|
||||
if cred_cache_key in self._credentials_cache:
|
||||
return ModelInstance(
|
||||
provider_model_bundle,
|
||||
model,
|
||||
deepcopy(self._credentials_cache[cred_cache_key]),
|
||||
)
|
||||
|
||||
ret = ModelInstance(provider_model_bundle, model)
|
||||
if self._enable_credentials_cache:
|
||||
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
|
||||
return ret
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
|
||||
@ -151,6 +151,12 @@ def deserialize_response(raw_data: bytes) -> Response:
|
||||
|
||||
response = Response(response=body, status=status_code)
|
||||
|
||||
# Replace Flask's default headers (e.g. Content-Type, Content-Length) with the
|
||||
# parsed ones so we faithfully reproduce the original response. Use Headers.add
|
||||
# rather than dict-style assignment so that repeated headers such as Set-Cookie
|
||||
# (and any other multi-valued header per RFC 9110) are preserved instead of
|
||||
# being overwritten.
|
||||
response.headers.clear()
|
||||
for line in lines[1:]:
|
||||
if not line:
|
||||
continue
|
||||
@ -158,6 +164,6 @@ def deserialize_response(raw_data: bytes) -> Response:
|
||||
if ":" not in line_str:
|
||||
continue
|
||||
name, value = line_str.split(":", 1)
|
||||
response.headers[name] = value.strip()
|
||||
response.headers.add(name, value.strip())
|
||||
|
||||
return response
|
||||
|
||||
@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@ -70,12 +70,32 @@ class ProviderManager:
|
||||
Request-bound managers may carry caller identity in that runtime, and the
|
||||
resulting ``ProviderConfiguration`` objects must reuse it for downstream
|
||||
model-type and schema lookups.
|
||||
|
||||
Configuration assembly is cached per manager instance so call chains that
|
||||
share one request-scoped manager can reuse the same provider graph instead
|
||||
of rebuilding it for every lookup. Call ``clear_configurations_cache()``
|
||||
when a long-lived manager needs to observe writes performed within the same
|
||||
instance scope.
|
||||
"""
|
||||
|
||||
decoding_rsa_key: Any | None
|
||||
decoding_cipher_rsa: Any | None
|
||||
_model_runtime: ModelRuntime
|
||||
_configurations_cache: dict[str, ProviderConfigurations]
|
||||
|
||||
def __init__(self, model_runtime: ModelRuntime):
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
self._model_runtime = model_runtime
|
||||
self._configurations_cache = {}
|
||||
|
||||
def clear_configurations_cache(self, tenant_id: str | None = None) -> None:
|
||||
"""Drop assembled provider configurations cached on this manager instance."""
|
||||
if tenant_id is None:
|
||||
self._configurations_cache.clear()
|
||||
return
|
||||
|
||||
self._configurations_cache.pop(tenant_id, None)
|
||||
|
||||
def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
|
||||
"""
|
||||
@ -114,6 +134,10 @@ class ProviderManager:
|
||||
:param tenant_id:
|
||||
:return:
|
||||
"""
|
||||
cached_configurations = self._configurations_cache.get(tenant_id)
|
||||
if cached_configurations is not None:
|
||||
return cached_configurations
|
||||
|
||||
# Get all provider records of the workspace
|
||||
provider_name_to_provider_records_dict = self._get_all_providers(tenant_id)
|
||||
|
||||
@ -273,6 +297,8 @@ class ProviderManager:
|
||||
|
||||
provider_configurations[str(provider_id_entity)] = provider_configuration
|
||||
|
||||
self._configurations_cache[tenant_id] = provider_configurations
|
||||
|
||||
# Return the encapsulated object
|
||||
return provider_configurations
|
||||
|
||||
@ -419,7 +445,7 @@ class ProviderManager:
|
||||
@staticmethod
|
||||
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
|
||||
provider_name_to_provider_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
|
||||
providers = session.scalars(stmt)
|
||||
for provider in providers:
|
||||
@ -436,7 +462,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_records_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
|
||||
provider_models = session.scalars(stmt)
|
||||
for provider_model in provider_models:
|
||||
@ -452,7 +478,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_preferred_provider_type_records_dict = {}
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
|
||||
preferred_provider_types = session.scalars(stmt)
|
||||
provider_name_to_preferred_provider_type_records_dict = {
|
||||
@ -470,7 +496,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_settings_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
|
||||
provider_model_settings = session.scalars(stmt)
|
||||
for provider_model_setting in provider_model_settings:
|
||||
@ -488,7 +514,7 @@ class ProviderManager:
|
||||
:return:
|
||||
"""
|
||||
provider_name_to_provider_model_credentials_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ProviderModelCredential).where(ProviderModelCredential.tenant_id == tenant_id)
|
||||
provider_model_credentials = session.scalars(stmt)
|
||||
for provider_model_credential in provider_model_credentials:
|
||||
@ -518,7 +544,7 @@ class ProviderManager:
|
||||
return {}
|
||||
|
||||
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
|
||||
provider_load_balancing_configs = session.scalars(stmt)
|
||||
for provider_load_balancing_config in provider_load_balancing_configs:
|
||||
@ -552,7 +578,7 @@ class ProviderManager:
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(
|
||||
@ -582,7 +608,7 @@ class ProviderManager:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
|
||||
@ -139,8 +139,10 @@ class Jieba(BaseKeyword):
|
||||
"__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table},
|
||||
}
|
||||
dataset_keyword_table = self.dataset.dataset_keyword_table
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type
|
||||
keyword_data_source_type = dataset_keyword_table.data_source_type if dataset_keyword_table else "file"
|
||||
if keyword_data_source_type == "database":
|
||||
if dataset_keyword_table is None:
|
||||
return
|
||||
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
|
||||
db.session.commit()
|
||||
else:
|
||||
@ -154,7 +156,8 @@ class Jieba(BaseKeyword):
|
||||
if dataset_keyword_table:
|
||||
keyword_table_dict = dataset_keyword_table.keyword_table_dict
|
||||
if keyword_table_dict:
|
||||
return dict(keyword_table_dict["__data__"]["table"])
|
||||
data: Any = keyword_table_dict["__data__"]
|
||||
return dict(data["table"])
|
||||
else:
|
||||
keyword_data_source_type = dify_config.KEYWORD_DATA_SOURCE_TYPE
|
||||
dataset_keyword_table = DatasetKeywordTable(
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from operator import itemgetter
|
||||
from typing import cast
|
||||
|
||||
@ -80,12 +81,14 @@ class JiebaKeywordTableHandler:
|
||||
|
||||
def extract_tags(self, sentence: str, top_k: int | None = 20, **kwargs):
|
||||
# Basic frequency-based keyword extraction as a fallback when TF-IDF is unavailable.
|
||||
top_k = kwargs.pop("topK", top_k)
|
||||
top_k = cast(int | None, kwargs.pop("topK", top_k))
|
||||
if top_k is None:
|
||||
top_k = 20
|
||||
cut = getattr(jieba, "cut", None)
|
||||
if self._lcut:
|
||||
tokens = self._lcut(sentence)
|
||||
elif callable(cut):
|
||||
tokens = list(cut(sentence))
|
||||
tokens = list(cast(Callable[[str], list[str]], cut)(sentence))
|
||||
else:
|
||||
tokens = re.findall(r"\w+", sentence)
|
||||
|
||||
@ -106,9 +109,9 @@ class JiebaKeywordTableHandler:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = self._tfidf.extract_tags(
|
||||
sentence=text,
|
||||
topK=max_keywords_per_chunk,
|
||||
topK=max_keywords_per_chunk or 10,
|
||||
)
|
||||
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
|
||||
# jieba.analyse.extract_tags returns an untyped list when withFlag is False by default.
|
||||
keywords = cast(list[str], keywords)
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(set(keywords)))
|
||||
|
||||
@ -158,7 +158,7 @@ class RetrievalService:
|
||||
)
|
||||
|
||||
if futures:
|
||||
for future in concurrent.futures.as_completed(futures, timeout=3600):
|
||||
for _ in concurrent.futures.as_completed(futures, timeout=3600):
|
||||
if exceptions:
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
@ -217,10 +217,11 @@ class RetrievalService:
|
||||
"""Deduplicate documents in O(n) while preserving first-seen order.
|
||||
|
||||
Rules:
|
||||
- For provider == "dify" and metadata["doc_id"] exists: keep the doc with the highest
|
||||
metadata["score"] among duplicates; if a later duplicate has no score, ignore it.
|
||||
- For non-dify documents (or dify without doc_id): deduplicate by content key
|
||||
(provider, page_content), keeping the first occurrence.
|
||||
- If metadata["doc_id"] exists (any provider): deduplicate by (provider, doc_id) key;
|
||||
keep the doc with the highest metadata["score"] among duplicates. If a later duplicate
|
||||
has no score, ignore it.
|
||||
- If metadata["doc_id"] is absent: deduplicate by content key (provider, page_content),
|
||||
keeping the first occurrence.
|
||||
"""
|
||||
if not documents:
|
||||
return documents
|
||||
@ -231,11 +232,10 @@ class RetrievalService:
|
||||
order: list[tuple] = []
|
||||
|
||||
for doc in documents:
|
||||
is_dify = doc.provider == "dify"
|
||||
doc_id = (doc.metadata or {}).get("doc_id") if is_dify else None
|
||||
doc_id = (doc.metadata or {}).get("doc_id")
|
||||
|
||||
if is_dify and doc_id:
|
||||
key = ("dify", doc_id)
|
||||
if doc_id:
|
||||
key = (doc.provider or "dify", doc_id)
|
||||
if key not in chosen:
|
||||
chosen[key] = doc
|
||||
order.append(key)
|
||||
@ -551,6 +551,7 @@ class RetrievalService:
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
for i in child_index_nodes:
|
||||
assert i.index_node_id
|
||||
segment_ids.append(i.segment_id)
|
||||
if i.segment_id in child_chunk_map:
|
||||
child_chunk_map[i.segment_id].append(i)
|
||||
|
||||
@ -39,6 +39,58 @@ class AbstractVectorFactory(ABC):
|
||||
return index_struct_dict
|
||||
|
||||
|
||||
class _LazyEmbeddings(Embeddings):
|
||||
"""Lazy proxy that defers materializing the real embedding model.
|
||||
|
||||
Constructing the real embeddings (via ``ModelManager.get_model_instance``)
|
||||
transitively calls ``FeatureService.get_features`` → ``BillingService``
|
||||
HTTP GETs (see ``provider_manager.py``). Cleanup paths
|
||||
(``delete_by_ids`` / ``delete`` / ``text_exists``) do not need embeddings
|
||||
at all, so deferring this until an ``embed_*`` method is actually invoked
|
||||
keeps cleanup tasks resilient to transient billing-API failures and avoids
|
||||
leaving stranded ``document_segments`` / ``child_chunks`` whenever billing
|
||||
hiccups.
|
||||
|
||||
Existing callers that perform create / search operations are unaffected:
|
||||
the first ``embed_*`` call materializes the underlying model and the
|
||||
behavior is identical from that point on.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self._dataset = dataset
|
||||
self._real: Embeddings | None = None
|
||||
|
||||
def _ensure(self) -> Embeddings:
|
||||
if self._real is None:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id)
|
||||
embedding_model = model_manager.get_model_instance(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
provider=self._dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=self._dataset.embedding_model,
|
||||
)
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._ensure().embed_documents(texts)
|
||||
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
|
||||
return self._ensure().embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
return self._ensure().embed_query(text)
|
||||
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
|
||||
return self._ensure().embed_multimodal_query(multimodel_document)
|
||||
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return await self._ensure().aembed_documents(texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> list[float]:
|
||||
return await self._ensure().aembed_query(text)
|
||||
|
||||
|
||||
class Vector:
|
||||
def __init__(self, dataset: Dataset, attributes: list | None = None):
|
||||
if attributes is None:
|
||||
@ -60,7 +112,11 @@ class Vector:
|
||||
"original_chunk_id",
|
||||
]
|
||||
self._dataset = dataset
|
||||
self._embeddings = self._get_embeddings()
|
||||
# Use a lazy proxy so cleanup paths (delete_by_ids / delete / text_exists)
|
||||
# never transitively trigger billing API calls during ``Vector(dataset)``
|
||||
# construction. The real embedding model is materialized only when an
|
||||
# ``embed_*`` method is actually invoked (i.e. create / search paths).
|
||||
self._embeddings: Embeddings = _LazyEmbeddings(dataset)
|
||||
self._attributes = attributes
|
||||
self._vector_processor = self._init_vector()
|
||||
|
||||
@ -88,8 +144,20 @@ class Vector:
|
||||
def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
|
||||
return get_vector_factory_class(vector_type)
|
||||
|
||||
@staticmethod
|
||||
def _filter_empty_text_documents(documents: list[Document]) -> list[Document]:
|
||||
filtered_documents = [document for document in documents if document.page_content.strip()]
|
||||
skipped_count = len(documents) - len(filtered_documents)
|
||||
if skipped_count:
|
||||
logger.warning("skip %d empty documents before vector embedding", skipped_count)
|
||||
return filtered_documents
|
||||
|
||||
def create(self, texts: list | None = None, **kwargs):
|
||||
if texts:
|
||||
texts = self._filter_empty_text_documents(texts)
|
||||
if not texts:
|
||||
return
|
||||
|
||||
start = time.time()
|
||||
logger.info("start embedding %s texts %s", len(texts), start)
|
||||
batch_size = 1000
|
||||
@ -147,8 +215,14 @@ class Vector:
|
||||
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
|
||||
|
||||
def add_texts(self, documents: list[Document], **kwargs):
|
||||
documents = self._filter_empty_text_documents(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if kwargs.get("duplicate_check", False):
|
||||
documents = self._filter_duplicate_texts(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
|
||||
self._vector_processor.create(texts=documents, embeddings=embeddings, **kwargs)
|
||||
|
||||
@ -11,6 +11,7 @@ from core.rag.models.document import AttachmentDocument, Document
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.enums import SegmentType
|
||||
|
||||
|
||||
class DatasetDocumentStore:
|
||||
@ -127,6 +128,7 @@ class DatasetDocumentStore:
|
||||
if save_child:
|
||||
if doc.children:
|
||||
for position, child in enumerate(doc.children, start=1):
|
||||
assert self._document_id
|
||||
child_segment = ChildChunk(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
dataset_id=self._dataset.id,
|
||||
@ -137,7 +139,7 @@ class DatasetDocumentStore:
|
||||
index_node_hash=child.metadata.get("doc_hash"),
|
||||
content=child.page_content,
|
||||
word_count=len(child.page_content),
|
||||
type="automatic",
|
||||
type=SegmentType.AUTOMATIC,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
db.session.add(child_segment)
|
||||
@ -163,6 +165,7 @@ class DatasetDocumentStore:
|
||||
)
|
||||
# add new child chunks
|
||||
for position, child in enumerate(doc.children, start=1):
|
||||
assert self._document_id
|
||||
child_segment = ChildChunk(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
dataset_id=self._dataset.id,
|
||||
@ -173,7 +176,7 @@ class DatasetDocumentStore:
|
||||
index_node_hash=child.metadata.get("doc_hash"),
|
||||
content=child.page_content,
|
||||
word_count=len(child.page_content),
|
||||
type="automatic",
|
||||
type=SegmentType.AUTOMATIC,
|
||||
created_by=self._user_id,
|
||||
)
|
||||
db.session.add(child_segment)
|
||||
|
||||
@ -94,6 +94,7 @@ class ExtractProcessor:
|
||||
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
|
||||
) -> list[Document]:
|
||||
if extract_setting.datasource_type == DatasourceType.FILE:
|
||||
upload_file = extract_setting.upload_file
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
upload_file = extract_setting.upload_file
|
||||
if not file_path:
|
||||
@ -104,6 +105,7 @@ class ExtractProcessor:
|
||||
storage.download(upload_file.key, file_path)
|
||||
input_file = Path(file_path)
|
||||
file_extension = input_file.suffix.lower()
|
||||
assert upload_file is not None, "upload_file is required"
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
extractor: BaseExtractor | None = None
|
||||
if etl_type == "Unstructured":
|
||||
|
||||
@ -28,10 +28,10 @@ class FunctionCallMultiDatasetRouter:
|
||||
SystemPromptMessage(content="You are a helpful AI assistant."),
|
||||
UserPromptMessage(content=query),
|
||||
]
|
||||
result: LLMResult = model_instance.invoke_llm(
|
||||
result: LLMResult = model_instance.invoke_llm( # pyright: ignore[reportCallIssue, reportArgumentType]
|
||||
prompt_messages=prompt_messages,
|
||||
tools=dataset_tools,
|
||||
stream=False,
|
||||
stream=False, # pyright: ignore[reportArgumentType]
|
||||
model_parameters={"temperature": 0.2, "top_p": 0.3, "max_tokens": 1500},
|
||||
)
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import re
|
||||
from collections.abc import Collection
|
||||
from collections.abc import Set as AbstractSet
|
||||
from typing import Any, Literal
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
@ -21,8 +21,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
def from_encoder[T: EnhanceRecursiveCharacterTextSplitter](
|
||||
cls: type[T],
|
||||
embedding_model_instance: ModelInstance | None,
|
||||
allowed_special: Literal["all"] | set[str] = set(),
|
||||
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
|
||||
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
|
||||
**kwargs: Any,
|
||||
) -> T:
|
||||
def _token_encoder(texts: list[str]) -> list[int]:
|
||||
@ -40,6 +40,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
|
||||
return [len(text) for text in texts]
|
||||
|
||||
_ = _token_encoder # kept for future token-length wiring
|
||||
return cls(length_function=_character_encoder, **kwargs)
|
||||
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Collection, Iterable, Sequence, Set
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -187,8 +188,8 @@ class TokenTextSplitter(TextSplitter):
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: str | None = None,
|
||||
allowed_special: Literal["all"] | Set[str] = set(),
|
||||
disallowed_special: Literal["all"] | Collection[str] = "all",
|
||||
allowed_special: Literal["all"] | AbstractSet[str] = frozenset(),
|
||||
disallowed_special: Literal["all"] | AbstractSet[str] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
@ -207,8 +208,8 @@ class TokenTextSplitter(TextSplitter):
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
self._tokenizer = enc
|
||||
self._allowed_special = allowed_special
|
||||
self._disallowed_special = disallowed_special
|
||||
self._allowed_special: Literal["all"] | AbstractSet[str] = allowed_special
|
||||
self._disallowed_special: Literal["all"] | AbstractSet[str] = disallowed_special
|
||||
|
||||
def split_text(self, text: str) -> list[str]:
|
||||
def _encode(_text: str) -> list[int]:
|
||||
|
||||
@ -1078,6 +1078,13 @@ class ToolManager:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
if variable_pool:
|
||||
config = tool_configurations.get(parameter.name, {})
|
||||
|
||||
selector_value = cls._extract_runtime_selector_value(parameter, config)
|
||||
if selector_value is not None:
|
||||
# Selector parameters carry structured dictionaries, not scalar ToolInput values.
|
||||
runtime_parameters[parameter.name] = selector_value
|
||||
continue
|
||||
|
||||
if not (config and isinstance(config, dict) and config.get("value") is not None):
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
@ -1105,5 +1112,39 @@ class ToolManager:
|
||||
runtime_parameters[parameter.name] = value
|
||||
return runtime_parameters
|
||||
|
||||
@classmethod
|
||||
def _extract_runtime_selector_value(cls, parameter: ToolParameter, config: Any) -> dict[str, Any] | None:
|
||||
if parameter.type not in {
|
||||
ToolParameter.ToolParameterType.MODEL_SELECTOR,
|
||||
ToolParameter.ToolParameterType.APP_SELECTOR,
|
||||
}:
|
||||
return None
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
|
||||
input_value = config.get("value")
|
||||
if isinstance(input_value, dict) and cls._is_selector_value(parameter, input_value):
|
||||
return cast("dict[str, Any]", parameter.init_frontend_parameter(input_value))
|
||||
|
||||
if cls._is_selector_value(parameter, config):
|
||||
selector_value = dict(config)
|
||||
selector_value.pop("type", None)
|
||||
selector_value.pop("value", None)
|
||||
return cast("dict[str, Any]", parameter.init_frontend_parameter(selector_value))
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _is_selector_value(cls, parameter: ToolParameter, value: Mapping[str, Any]) -> bool:
|
||||
if parameter.type == ToolParameter.ToolParameterType.MODEL_SELECTOR:
|
||||
return (
|
||||
isinstance(value.get("provider"), str)
|
||||
and isinstance(value.get("model"), str)
|
||||
and isinstance(value.get("model_type"), str)
|
||||
)
|
||||
if parameter.type == ToolParameter.ToolParameterType.APP_SELECTOR:
|
||||
return isinstance(value.get("app_id"), str)
|
||||
return False
|
||||
|
||||
|
||||
ToolManager.load_hardcoded_providers_cache()
|
||||
|
||||
@ -14,23 +14,23 @@ from configs import dify_config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
@ -43,19 +43,19 @@ class SystemOAuthEncrypter:
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
params: Parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
@ -66,7 +66,7 @@ class SystemOAuthEncrypter:
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
@ -76,20 +76,20 @@ class SystemOAuthEncrypter:
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
@ -118,70 +118,70 @@ class SystemOAuthEncrypter:
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: str | None = None) -> SystemOAuthEncrypter:
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: SystemOAuthEncrypter | None = None
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
params: Parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@ -105,7 +105,7 @@ class Article:
|
||||
|
||||
|
||||
def extract_using_readabilipy(html: str):
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True)
|
||||
json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=False)
|
||||
article = Article(
|
||||
title=json_article.get("title") or "",
|
||||
author=json_article.get("byline") or "",
|
||||
|
||||
@ -272,6 +272,14 @@ def _adapt_tool_node_data_for_graph(node_data: Mapping[str, Any]) -> dict[str, A
|
||||
normalized_tool_configurations[name] = value
|
||||
continue
|
||||
|
||||
selector_value = _extract_selector_configuration(value)
|
||||
if selector_value is not None:
|
||||
# Model/app selectors are dictionaries even when they come through the legacy tool configuration path.
|
||||
# Move them to tool_parameters so graph validation does not flatten them as primitive constants.
|
||||
found_legacy_tool_inputs = True
|
||||
normalized_tool_parameters.setdefault(name, {"type": "constant", "value": selector_value})
|
||||
continue
|
||||
|
||||
input_type = value.get("type")
|
||||
input_value = value.get("value")
|
||||
if input_type not in {"mixed", "variable", "constant"}:
|
||||
@ -310,6 +318,28 @@ def _flatten_legacy_tool_configuration_value(*, input_type: Any, input_value: An
|
||||
return None
|
||||
|
||||
|
||||
def _extract_selector_configuration(value: Mapping[str, Any]) -> dict[str, Any] | None:
|
||||
input_value = value.get("value")
|
||||
if isinstance(input_value, Mapping) and _is_selector_configuration(input_value):
|
||||
return dict(input_value)
|
||||
|
||||
if _is_selector_configuration(value):
|
||||
selector_value = dict(value)
|
||||
selector_value.pop("type", None)
|
||||
selector_value.pop("value", None)
|
||||
return selector_value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_selector_configuration(value: Mapping[str, Any]) -> bool:
|
||||
return (
|
||||
isinstance(value.get("provider"), str)
|
||||
and isinstance(value.get("model"), str)
|
||||
and isinstance(value.get("model_type"), str)
|
||||
) or isinstance(value.get("app_id"), str)
|
||||
|
||||
|
||||
def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
|
||||
normalized = dict(recipients)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from collections.abc import Sequence
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.human_input_policy import get_preferred_form_token
|
||||
from core.workflow.human_input_policy import HumanInputSurface, get_preferred_form_token
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
|
||||
@ -21,6 +21,7 @@ def load_form_tokens_by_form_id(
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
session: Session | None = None,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Load the preferred access token for each human input form."""
|
||||
unique_form_ids = list(dict.fromkeys(form_ids))
|
||||
@ -28,23 +29,43 @@ def load_form_tokens_by_form_id(
|
||||
return {}
|
||||
|
||||
if session is not None:
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(session, unique_form_ids, surface=surface)
|
||||
|
||||
with Session(bind=db.engine, expire_on_commit=False) as new_session:
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids)
|
||||
return _load_form_tokens_by_form_id(new_session, unique_form_ids, surface=surface)
|
||||
|
||||
|
||||
def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
|
||||
def _load_form_tokens_by_form_id(
|
||||
session: Session,
|
||||
form_ids: Sequence[str],
|
||||
*,
|
||||
surface: HumanInputSurface | None = None,
|
||||
) -> dict[str, str]:
|
||||
recipients_by_form_id: dict[str, list[tuple[RecipientType, str]]] = {}
|
||||
stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(stmt):
|
||||
if not recipient.access_token:
|
||||
continue
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append((recipient.recipient_type, recipient.access_token))
|
||||
recipients_by_form_id.setdefault(recipient.form_id, []).append(
|
||||
(recipient.recipient_type, recipient.access_token)
|
||||
)
|
||||
|
||||
tokens_by_form_id: dict[str, str] = {}
|
||||
for form_id, recipients in recipients_by_form_id.items():
|
||||
token = get_preferred_form_token(recipients)
|
||||
token = _get_surface_form_token(recipients, surface=surface)
|
||||
if token is not None:
|
||||
tokens_by_form_id[form_id] = token
|
||||
return tokens_by_form_id
|
||||
|
||||
|
||||
def _get_surface_form_token(
|
||||
recipients: Sequence[tuple[RecipientType, str]],
|
||||
*,
|
||||
surface: HumanInputSurface | None,
|
||||
) -> str | None:
|
||||
if surface == HumanInputSurface.SERVICE_API:
|
||||
for recipient_type, token in recipients:
|
||||
if recipient_type == RecipientType.STANDALONE_WEB_APP and token:
|
||||
return token
|
||||
|
||||
return get_preferred_form_token(recipients)
|
||||
|
||||
@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from graphon.entities.pause_reason import PauseReasonType
|
||||
from models.human_input import RecipientType
|
||||
|
||||
|
||||
@ -61,7 +62,7 @@ def enrich_human_input_pause_reasons(
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for reason in reasons:
|
||||
updated = dict(reason)
|
||||
if updated.get("type") == "human_input_required":
|
||||
if updated.get("TYPE") == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_id = updated.get("form_id")
|
||||
if isinstance(form_id, str):
|
||||
updated["form_token"] = form_tokens_by_form_id.get(form_id)
|
||||
|
||||
@ -365,7 +365,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
(including pydantic ValidationError, which subclasses ValueError),
|
||||
if node type is unknown, or if no implementation exists for the resolved version
|
||||
"""
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapt_node_config_for_graph(node_config))
|
||||
adapted_node_config = adapt_node_config_for_graph(node_config)
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(adapted_node_config)
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
@ -373,6 +374,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
# Re-validate using the resolved node class so workflow-local node schemas
|
||||
# stay explicit and constructors receive the concrete typed payload.
|
||||
resolved_node_data = self._validate_resolved_node_data(node_class, node_data)
|
||||
config_for_node_init: BaseNodeData | dict[str, Any]
|
||||
if isinstance(resolved_node_data, BaseNodeData):
|
||||
config_for_node_init = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
else:
|
||||
config_for_node_init = resolved_node_data
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
BuiltinNodeTypes.CODE: lambda: {
|
||||
@ -442,7 +448,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
node_id=node_id,
|
||||
config=resolved_node_data,
|
||||
config=config_for_node_init,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
**node_init_kwargs,
|
||||
@ -474,10 +480,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
include_retriever_attachment_loader: bool,
|
||||
include_jinja2_template_renderer: bool,
|
||||
) -> dict[str, object]:
|
||||
validated_node_data = cast(
|
||||
LLMCompatibleNodeData,
|
||||
self._validate_resolved_node_data(node_class=node_class, node_data=node_data),
|
||||
)
|
||||
validated_node_data = cast(LLMCompatibleNodeData, node_data)
|
||||
model_instance = self._build_model_instance_for_llm_node(validated_node_data)
|
||||
node_init_kwargs: dict[str, object] = {
|
||||
"credentials_provider": self._llm_credentials_provider,
|
||||
|
||||
@ -501,11 +501,15 @@ class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
|
||||
tool_configurations = dict(node_data.tool_configurations)
|
||||
tool_configurations.update(
|
||||
{name: tool_input.model_dump(mode="python") for name, tool_input in node_data.tool_parameters.items()}
|
||||
)
|
||||
return _WorkflowToolRuntimeSpec(
|
||||
provider_type=CoreToolProviderType(node_data.provider_type.value),
|
||||
provider_id=node_data.provider_id,
|
||||
tool_name=node_data.tool_name,
|
||||
tool_configurations=dict(node_data.tool_configurations),
|
||||
tool_configurations=tool_configurations,
|
||||
credential_id=node_data.credential_id,
|
||||
)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user