mirror of
https://github.com/langgenius/dify.git
synced 2026-04-20 02:37:20 +08:00
Compare commits
152 Commits
feat/marke
...
feat/per-d
| Author | SHA1 | Date | |
|---|---|---|---|
| e93fdddb20 | |||
| ccfc8c6f15 | |||
| 4fb3fab82d | |||
| 3cea0dfb07 | |||
| 0d6db3a3f3 | |||
| 3d5a81bd30 | |||
| 208604a3a8 | |||
| 63bfba0bdb | |||
| 9948a51b14 | |||
| 0e0bb3582f | |||
| 546062d2cd | |||
| aad0b3c157 | |||
| 4d4265f531 | |||
| e138523123 | |||
| a65e1f71b4 | |||
| 909c062ee1 | |||
| f5322e45fc | |||
| 017f09f1e9 | |||
| 0ba66ab155 | |||
| 5cd267d755 | |||
| d30946dabf | |||
| b0e524213e | |||
| b1adb5652e | |||
| c825d5dcf6 | |||
| 2127d5850f | |||
| ae9fcc2969 | |||
| 624db69f12 | |||
| 80a7843f45 | |||
| cb55176612 | |||
| 5aa2524d33 | |||
| 2575a3a3ab | |||
| f8f7b0ec1a | |||
| d2ee486900 | |||
| c44ddd9831 | |||
| e645cbd8f8 | |||
| 485fc2c416 | |||
| f09be969bb | |||
| 597a0b4d9f | |||
| 779cce3c61 | |||
| b5d9a71cf9 | |||
| c2af415450 | |||
| 89ce61cfea | |||
| 05c5327f47 | |||
| 3891c0a255 | |||
| 63b1d0c1ea | |||
| 75ed38fb3d | |||
| 63db9a7a2f | |||
| 19c80f0f0e | |||
| c5a0bde3ec | |||
| 1261e5e5e8 | |||
| e2ecd68556 | |||
| bceb0eee9b | |||
| 173e818a62 | |||
| 84d8940dbf | |||
| 3e995e6a6d | |||
| 459c36f21b | |||
| 72adb5468c | |||
| 1194957fde | |||
| 68bd29eda2 | |||
| f67a811f7f | |||
| b9c122e7f4 | |||
| 396b39dff9 | |||
| ac8bd12609 | |||
| b55bef4438 | |||
| 2f9667de76 | |||
| a7b6307d32 | |||
| 2883ad6764 | |||
| 0feff5b048 | |||
| 0bce6b35b4 | |||
| 89e23456f0 | |||
| a39173c227 | |||
| 12e93d374f | |||
| 922f9242e4 | |||
| 7fc0a791a2 | |||
| 8d37116fec | |||
| 4b500f988d | |||
| 5ad906ea6a | |||
| 5b862a43e0 | |||
| 1e5cd69205 | |||
| 9081c46565 | |||
| 40b252be8c | |||
| ba1357038a | |||
| 46d1f4c338 | |||
| 9c880dd650 | |||
| 01ba0e050f | |||
| ccc4aae94e | |||
| 01242e13d7 | |||
| 938ee27e42 | |||
| a101f72153 | |||
| 40642433d8 | |||
| 8979181d5e | |||
| c17c6b5c35 | |||
| e83a4090ac | |||
| b71b9f80b9 | |||
| ee87289917 | |||
| 5ad8c3e249 | |||
| 8b992513b8 | |||
| eca0cdc7a9 | |||
| 779e6b8e0b | |||
| c2428361c4 | |||
| 68e4d13f36 | |||
| cb9f4bb100 | |||
| 8a398f3105 | |||
| 0f051d5886 | |||
| e85d9a0d72 | |||
| 06dde4f503 | |||
| 83d4176785 | |||
| c94951b2f8 | |||
| a9cf8f6c5d | |||
| 64ddec0d67 | |||
| da3b0caf5e | |||
| 4fedd43af5 | |||
| a263f28e19 | |||
| d53862f135 | |||
| 608958de1c | |||
| 7eb632eb34 | |||
| 33d4fd357c | |||
| e55bd61c17 | |||
| f2fc213d52 | |||
| f814579ed2 | |||
| 71d299d0d3 | |||
| e178451d04 | |||
| 9a6222f245 | |||
| affe5ed30b | |||
| 4cc5401d7e | |||
| 36e840cd87 | |||
| 985b41c40b | |||
| 2e29ac2829 | |||
| dbfb474eab | |||
| d243de26ec | |||
| 894826771a | |||
| a3386da5d6 | |||
| 318a3d0308 | |||
| 5bafb163cc | |||
| 52b1bc5b09 | |||
| 1873b22e96 | |||
| 9a8c853a2e | |||
| e54383d0fe | |||
| 43c48ba4d7 | |||
| 8f9dbf269e | |||
| cb9ee5903a | |||
| cd406d2794 | |||
| 993a301468 | |||
| 399d3f8da5 | |||
| f9d9ad7a38 | |||
| 2d29345f26 | |||
| 725f9e3dc4 | |||
| 4e1d060439 | |||
| 391007d02e | |||
| e41965061c | |||
| 2b9eb06555 | |||
| 31f7752ba9 |
@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
|
||||
onSuccess: () => toast.success('...'),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
@ -114,10 +114,7 @@ try {
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
toast.error(error instanceof Error ? error.message : 'Unknown error')
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@ -1,3 +1,10 @@
|
||||
web:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'web/**'
|
||||
- any-glob-to-any-file:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@ -20,4 +20,4 @@
|
||||
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
|
||||
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
|
||||
- [x] I've updated the documentation accordingly.
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods
|
||||
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
|
||||
|
||||
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
@ -0,0 +1,82 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
outputPath,
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@ -39,9 +39,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
|
||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@ -65,7 +65,7 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
@ -130,7 +130,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@ -8,9 +8,11 @@ on:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
|
||||
4
.github/workflows/main-ci.yml
vendored
4
.github/workflows/main-ci.yml
vendored
@ -65,9 +65,11 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@ -77,9 +79,11 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
|
||||
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -77,9 +77,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
@ -149,7 +151,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
||||
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@ -9,6 +9,7 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
103
.github/workflows/translate-i18n-claude.yml
vendored
103
.github/workflows/translate-i18n-claude.yml
vendored
@ -68,89 +68,7 @@ jobs:
|
||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||
|
||||
generate_changes_json() {
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
}
|
||||
|
||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||
@ -240,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
|
||||
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@ -270,7 +188,7 @@ jobs:
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `pnpm`.
|
||||
- Use Bash only for `vp`.
|
||||
- Do not use Bash for `git`, `gh`, or branch management.
|
||||
|
||||
Required execution plan:
|
||||
@ -292,7 +210,7 @@ jobs:
|
||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
@ -300,19 +218,19 @@ jobs:
|
||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
6. Verify only the edited files.
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
|
||||
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Stop after the scoped locale files are updated and verification passes.
|
||||
- Do not create branches, commits, or pull requests.
|
||||
claude_args: |
|
||||
--max-turns 120
|
||||
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
|
||||
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
|
||||
|
||||
- name: Prepare branch metadata
|
||||
id: pr_meta
|
||||
@ -354,6 +272,7 @@ jobs:
|
||||
- name: Create or update translation PR
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||
@ -402,8 +321,8 @@ jobs:
|
||||
'',
|
||||
'## Verification',
|
||||
'',
|
||||
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
|
||||
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
|
||||
'',
|
||||
'## Notes',
|
||||
'',
|
||||
|
||||
83
.github/workflows/trigger-i18n-sync.yml
vendored
83
.github/workflows/trigger-i18n-sync.yml
vendored
@ -42,88 +42,7 @@ jobs:
|
||||
fi
|
||||
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = readCurrentJson(fileStem) || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: readCurrentJson(fileStem) === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@ -36,7 +36,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -212,7 +212,8 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
/node_modules
|
||||
node_modules
|
||||
.vite-hooks/_
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
|
||||
34
web/.husky/pre-commit → .vite-hooks/pre-commit
Normal file → Executable file
34
web/.husky/pre-commit → .vite-hooks/pre-commit
Normal file → Executable file
@ -77,42 +77,22 @@ if $web_modified; then
|
||||
fi
|
||||
|
||||
cd ./web || exit 1
|
||||
lint-staged
|
||||
vp staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! pnpm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
|
||||
if ! npm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running unit tests check"
|
||||
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
|
||||
|
||||
if [ -n "$modified_files" ]; then
|
||||
for file in $modified_files; do
|
||||
test_file="${file%.*}.spec.ts"
|
||||
echo "Checking for test file: $test_file"
|
||||
|
||||
# check if the test file exists
|
||||
if [ -f "../$test_file" ]; then
|
||||
echo "Detected changes in $file, running corresponding unit tests..."
|
||||
pnpm run test "../$test_file"
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Unit tests failed. Please fix the errors before committing."
|
||||
exit 1
|
||||
fi
|
||||
echo "Unit tests for $file passed."
|
||||
else
|
||||
echo "Warning: $file does not have a corresponding test file."
|
||||
fi
|
||||
|
||||
done
|
||||
echo "All unit tests for modified web/utils files have passed."
|
||||
echo "Running knip"
|
||||
if ! npm run knip; then
|
||||
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd ../
|
||||
@ -115,12 +115,6 @@ ignore = [
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.pyflakes]
|
||||
allowed-unused-imports = [
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
|
||||
18
api/celery_healthcheck.py
Normal file
18
api/celery_healthcheck.py
Normal file
@ -0,0 +1,18 @@
|
||||
# This module provides a lightweight Celery instance for use in Docker health checks.
|
||||
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
|
||||
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
|
||||
# Using this module keeps the health check fast and low-cost.
|
||||
from celery import Celery
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
|
||||
|
||||
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
|
||||
|
||||
broker_transport_options = get_celery_broker_transport_options()
|
||||
if broker_transport_options:
|
||||
celery.conf.update(broker_transport_options=broker_transport_options)
|
||||
|
||||
ssl_options = get_celery_ssl_options()
|
||||
if ssl_options:
|
||||
celery.conf.update(broker_use_ssl=ssl_options)
|
||||
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
@ -503,7 +503,19 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
|
||||
return [row[0] for row in result]
|
||||
|
||||
|
||||
def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
class _AppOrphanCounts(TypedDict):
|
||||
variables: int
|
||||
files: int
|
||||
|
||||
|
||||
class OrphanedDraftVariableStatsDict(TypedDict):
|
||||
total_orphaned_variables: int
|
||||
total_orphaned_files: int
|
||||
orphaned_app_count: int
|
||||
orphaned_by_app: dict[str, _AppOrphanCounts]
|
||||
|
||||
|
||||
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
|
||||
"""
|
||||
Count orphaned draft variables by app, including associated file counts.
|
||||
|
||||
@ -526,7 +538,7 @@ def _count_orphaned_draft_variables() -> dict[str, Any]:
|
||||
|
||||
with db.engine.connect() as conn:
|
||||
result = conn.execute(sa.text(variables_query))
|
||||
orphaned_by_app = {}
|
||||
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
|
||||
total_files = 0
|
||||
|
||||
for row in result:
|
||||
|
||||
@ -134,6 +134,26 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
DIFY_DB_USER: str | None = Field(
|
||||
description="Override for DB_USERNAME. Takes precedence when set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
DIFY_DB_PASS: str | None = Field(
|
||||
description="Override for DB_PASSWORD. Takes precedence when set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def effective_db_username(self) -> str:
|
||||
return self.DIFY_DB_USER if self.DIFY_DB_USER is not None else self.DB_USERNAME
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def effective_db_password(self) -> str:
|
||||
return self.DIFY_DB_PASS if self.DIFY_DB_PASS is not None else self.DB_PASSWORD
|
||||
|
||||
DB_DATABASE: str = Field(
|
||||
description="Name of the database to connect to.",
|
||||
default="dify",
|
||||
@ -163,7 +183,7 @@ class DatabaseConfig(BaseSettings):
|
||||
db_extras = f"?{db_extras}" if db_extras else ""
|
||||
return (
|
||||
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
|
||||
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{quote_plus(self.effective_db_username)}:{quote_plus(self.effective_db_password)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{db_extras}"
|
||||
)
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, TypeVar, final, runtime_checkable
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -188,8 +188,6 @@ class ExecutionContextBuilder:
|
||||
_capturer: Callable[[], IExecutionContext] | None = None
|
||||
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ContextProviderNotFoundError(KeyError):
|
||||
"""Raised when a tenant-scoped context provider is missing."""
|
||||
|
||||
@ -1,7 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
@ -11,7 +8,7 @@ class HiddenValue:
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
class RecyclableContextVar[T]:
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
63
api/controllers/common/controller_schemas.py
Normal file
63
api/controllers/common/controller_schemas.py
Normal file
@ -0,0 +1,63 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
|
||||
# --- Conversation schemas ---
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
# --- Message schemas ---
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
# --- Saved message schemas ---
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
# --- Workflow schemas ---
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
type JSONObject = dict[str, Any]
|
||||
|
||||
|
||||
class SystemParameters(BaseModel):
|
||||
|
||||
@ -2,7 +2,7 @@ import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -18,10 +18,7 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from services.billing_service import BillingService, LangContentDict
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -72,9 +69,9 @@ console_ns.schema_model(
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
@ -332,7 +329,7 @@ class UpsertNotificationApi(Resource):
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest
|
||||
@ -26,9 +26,11 @@ from controllers.console.wraps import (
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.trigger.constants import TRIGGER_NODE_TYPES
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, DatasetPermissionEnum, Workflow
|
||||
from models.model import IconType
|
||||
@ -41,10 +43,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
NotionIcon,
|
||||
NotionInfo,
|
||||
NotionPage,
|
||||
PreProcessingRule,
|
||||
RerankingModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
WebsiteInfo,
|
||||
WeightKeywordSetting,
|
||||
WeightModel,
|
||||
@ -152,17 +151,7 @@ class AppTracePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
JSONValue: TypeAlias = Any
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
serialize_by_alias=True,
|
||||
protected_namespaces=(),
|
||||
)
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
|
||||
@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
def validate_node_id(node_id: str) -> None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
|
||||
@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger)
|
||||
.where(
|
||||
WorkflowWebhookTrigger.app_id == app_model.id,
|
||||
WorkflowWebhookTrigger.node_id == node_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not webhook_trigger:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
from typing import overload
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -9,11 +9,6 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -28,10 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
@ -69,10 +84,30 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -20,35 +20,18 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from libs.password import hash_password
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr = Field(...)
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
|
||||
class ForgotPasswordEmailResponse(BaseModel):
|
||||
result: str = Field(description="Operation result")
|
||||
data: str | None = Field(default=None, description="Reset token")
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@ -42,8 +40,9 @@ from libs.token import (
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
@ -51,9 +50,7 @@ from services.feature_service import FeatureService
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr = Field(..., description="Email address")
|
||||
password: str = Field(..., description="Password")
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
remember_me: bool = Field(default=False, description="Remember me flag")
|
||||
invite_token: str | None = Field(default=None, description="Invitation token")
|
||||
|
||||
@ -101,7 +98,7 @@ class LoginApi(Resource):
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
invitation_data: dict[str, Any] | None = None
|
||||
invitation_data: InvitationDetailDict | None = None
|
||||
if invite_token:
|
||||
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
|
||||
if invitation_data is None:
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask.typing import ResponseReturnValue
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
def oauth_server_client_id_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
def oauth_server_access_token_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
|
||||
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(
|
||||
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
|
||||
) -> R | ResponseReturnValue:
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
|
||||
@ -158,10 +158,11 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
|
||||
@ -173,8 +173,11 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -86,8 +87,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
template = session.scalar(
|
||||
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -70,7 +71,7 @@ def _api_prerequisite(f):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -593,17 +593,15 @@ class PublishedRagPipelineApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pipeline = session.merge(pipeline)
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
session.add(pipeline)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -9,13 +8,10 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Pipeline
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def get_rag_pipeline(view_func: Callable[P, R]):
|
||||
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
|
||||
@ -2,10 +2,10 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@ -32,14 +32,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(console_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
|
||||
pinned: bool | None = None
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@ -3,9 +3,10 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
@ -44,17 +44,6 @@ from .. import console_ns
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class MoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
|
||||
|
||||
@ -1,28 +1,18 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@ -34,12 +33,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[P, R]):
|
||||
def trial_feature_enable[**P, R](view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[P, R]):
|
||||
def explore_banner_enabled[**P, R](view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
@ -11,6 +13,21 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
title_pic_url: str
|
||||
|
||||
|
||||
class NotificationResponseDict(TypedDict):
|
||||
should_show: bool
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
@ -45,28 +62,30 @@ class NotificationApi(Resource):
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
response: NotificationResponseDict
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
response = {"should_show": False, "notifications": []}
|
||||
return response, 200
|
||||
|
||||
lang = current_user.interface_language or _FALLBACK_LANG
|
||||
|
||||
notifications = []
|
||||
notifications: list[NotificationItemDict] = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
notifications.append(item)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
response = {"should_show": bool(notifications), "notifications": notifications}
|
||||
return response, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
|
||||
@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.tag_service import TagService
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
TagBindingCreatePayload,
|
||||
TagBindingDeletePayload,
|
||||
TagService,
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
dataset_tag_fields = {
|
||||
"id": fields.String,
|
||||
@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
|
||||
|
||||
class TagBasePayload(BaseModel):
|
||||
name: str = Field(description="Tag name", min_length=1, max_length=50)
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str] = Field(description="Tag IDs to bind")
|
||||
target_id: str = Field(description="Target ID to bind tags to")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagBindingRemovePayload(BaseModel):
|
||||
tag_id: str = Field(description="Tag ID to remove")
|
||||
target_id: str = Field(description="Target ID to unbind tag from")
|
||||
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
|
||||
type: TagType = Field(description="Tag type")
|
||||
|
||||
|
||||
class TagListQueryParam(BaseModel):
|
||||
@ -82,7 +89,7 @@ class TagListApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.save_tags(payload.model_dump())
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBasePayload.model_validate(console_ns.payload or {})
|
||||
tag = TagService.update_tags(payload.model_dump(), tag_id)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(payload.model_dump())
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(payload.model_dump())
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@ -9,28 +9,25 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission)
|
||||
permission = session.scalar(
|
||||
select(TenantPluginPermission)
|
||||
.where(
|
||||
TenantPluginPermission.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not permission:
|
||||
|
||||
@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -240,8 +240,10 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
|
||||
custom_config_dict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand,
|
||||
custom_config_dict: TenantCustomConfigDict = {
|
||||
"remove_webapp_brand": args.remove_webapp_brand
|
||||
if args.remove_webapp_brand is not None
|
||||
else tenant.custom_config_dict.get("remove_webapp_brand", False),
|
||||
"replace_webapp_logo": args.replace_webapp_logo
|
||||
if args.replace_webapp_logo is not None
|
||||
else tenant.custom_config_dict.get("replace_webapp_logo"),
|
||||
|
||||
@ -4,7 +4,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@ -25,9 +24,6 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Field names for decryption
|
||||
FIELD_NAME_PASSWORD = "password"
|
||||
FIELD_NAME_CODE = "code"
|
||||
@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check account initialization
|
||||
@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_register_enabled(view: Callable[P, R]):
|
||||
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
||||
@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def edit_permission_required(f: Callable[P, R]):
|
||||
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
|
||||
return decorated_function
|
||||
|
||||
|
||||
def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
return decorated_function
|
||||
|
||||
|
||||
def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Rate limiting decorator for annotation import operations.
|
||||
|
||||
@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def annotation_import_concurrency_limit(view: Callable[P, R]):
|
||||
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Concurrency control decorator for annotation import operations.
|
||||
|
||||
@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
|
||||
payload[field_name] = decoded_value
|
||||
|
||||
|
||||
def decrypt_password_field(view: Callable[P, R]):
|
||||
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to decrypt password field in request payload.
|
||||
|
||||
@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def decrypt_code_field(view: Callable[P, R]):
|
||||
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to decrypt verification code field in request payload.
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -64,7 +64,6 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -1,21 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
@ -33,7 +29,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
@ -56,7 +52,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.flush()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant(view_func: Callable[P, R]):
|
||||
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
user_id = payload.user_id
|
||||
@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
def plugin_data[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
payload_type: type[BaseModel],
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
|
||||
@ -3,10 +3,7 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@ -14,9 +11,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
@ -80,11 +81,11 @@ class MCPAppApi(Resource):
|
||||
|
||||
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
|
||||
"""Get and validate MCP server and app in one query session"""
|
||||
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
|
||||
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
|
||||
if not mcp_server:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
|
||||
|
||||
app = session.query(App).where(App.id == mcp_server.app_id).first()
|
||||
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
|
||||
if not app:
|
||||
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
|
||||
|
||||
@ -190,12 +191,12 @@ class MCPAppApi(Resource):
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||
"""Get end user - manages its own database session"""
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
return (
|
||||
session.query(EndUser)
|
||||
return session.scalar(
|
||||
select(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
.where(EndUser.session_id == mcp_server_id)
|
||||
.where(EndUser.type == "mcp")
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
def _create_end_user(
|
||||
|
||||
@ -2,11 +2,12 @@ from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
@ -116,7 +105,7 @@ class ConversationApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
@ -27,17 +26,6 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -8,9 +8,10 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
@ -314,7 +313,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.enums import TagType
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
TagBindingCreatePayload,
|
||||
TagBindingDeletePayload,
|
||||
TagService,
|
||||
UpdateTagPayload,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
params = {"name": payload.name, "type": "knowledge"}
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(params, tag_id)
|
||||
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
TagService.save_tag_binding(
|
||||
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
TagService.delete_tag_binding(
|
||||
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ from controllers.service_api.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.rag.entities import PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
@ -40,11 +41,8 @@ from models.enums import SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
KnowledgeConfig,
|
||||
PreProcessingRule,
|
||||
ProcessRule,
|
||||
RetrievalModel,
|
||||
Rule,
|
||||
Segmentation,
|
||||
)
|
||||
from services.file_service import FileService
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
@ -4,13 +4,23 @@ Serialization helpers for Service API knowledge pipeline endpoints.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
||||
class UploadFileDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str | None
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
||||
from typing import cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@ -23,10 +24,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -46,16 +43,16 @@ class FetchUserArg(BaseModel):
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(
|
||||
def validate_app_token[**P, R](
|
||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_app_token(
|
||||
def validate_app_token[**P, R](
|
||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@ -136,7 +133,10 @@ def validate_app_token(
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_resource_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
@ -166,7 +166,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@ -188,7 +191,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@ -225,99 +231,73 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
||||
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
positional_parameters = [
|
||||
parameter
|
||||
for parameter in inspect.signature(view).parameters.values()
|
||||
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
]
|
||||
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: object, **kwargs: object) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
||||
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
if not dataset_id and args:
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
def validate_dataset_token(
|
||||
view: Callable[Concatenate[T, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
|
||||
# First try to get from kwargs (explicit parameter)
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
# If not in kwargs, try to extract from positional args
|
||||
if not dataset_id and args:
|
||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
||||
# This is a class method, dataset_id should be in args[1]
|
||||
potential_id = args[1]
|
||||
# Validate it's a string-like UUID, not another object
|
||||
try:
|
||||
# Try to convert to string and check if it's a valid UUID format
|
||||
str_id = str(potential_id)
|
||||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from class method args")
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
|
||||
return decorated
|
||||
if expects_bound_instance:
|
||||
if not args:
|
||||
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
return decorated
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
from controllers.trigger import bp
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
||||
webhook_id, is_debug=is_debug
|
||||
)
|
||||
|
||||
webhook_data: RawWebhookDataDict
|
||||
try:
|
||||
# Use new unified extraction and validation
|
||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
@ -3,10 +3,11 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
@ -34,12 +35,7 @@ from services.errors.audio import (
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@ -19,33 +18,15 @@ from controllers.console.error import EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.password import hash_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class ForgotPasswordSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
|
||||
|
||||
class ForgotPasswordCheckPayload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str
|
||||
token: str = Field(min_length=1)
|
||||
|
||||
|
||||
class ForgotPasswordResetPayload(BaseModel):
|
||||
token: str = Field(min_length=1)
|
||||
new_password: str
|
||||
password_confirm: str
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
from services.entities.auth_entities import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)
|
||||
|
||||
|
||||
@ -29,13 +29,11 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class LoginPayload(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
|
||||
@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
|
||||
@ -1,27 +1,17 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
|
||||
def validate_jwt_token[**P, R](
|
||||
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
|
||||
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
app_model, end_user = decode_jwt_token()
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
|
||||
@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
|
||||
return decorator
|
||||
|
||||
|
||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
|
||||
system_features = FeatureService.get_system_features()
|
||||
if not app_code:
|
||||
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
||||
|
||||
@ -79,21 +79,18 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
if not agent_scratchpad:
|
||||
assistant_messages = []
|
||||
else:
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
|
||||
content = ""
|
||||
for unit in agent_scratchpad:
|
||||
if unit.is_final():
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||
content += f"Final Answer: {unit.agent_response}"
|
||||
else:
|
||||
assert isinstance(assistant_message.content, str)
|
||||
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||
content += f"Thought: {unit.thought}\n\n"
|
||||
if unit.action_str:
|
||||
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||
content += f"Action: {unit.action_str}\n\n"
|
||||
if unit.observation:
|
||||
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||
content += f"Observation: {unit.observation}\n\n"
|
||||
|
||||
assistant_messages = [assistant_message]
|
||||
assistant_messages = [AssistantPromptMessage(content=content)]
|
||||
|
||||
# query messages
|
||||
query_messages = self._organize_user_query(self._query, [])
|
||||
|
||||
@ -5,6 +5,10 @@ from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
class FeatureToggleDict(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class SystemParametersDict(TypedDict):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
@ -16,12 +20,12 @@ class SystemParametersDict(TypedDict):
|
||||
class AppParametersDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: dict[str, Any]
|
||||
speech_to_text: dict[str, Any]
|
||||
text_to_speech: dict[str, Any]
|
||||
retriever_resource: dict[str, Any]
|
||||
annotation_reply: dict[str, Any]
|
||||
more_like_this: dict[str, Any]
|
||||
suggested_questions_after_answer: FeatureToggleDict
|
||||
speech_to_text: FeatureToggleDict
|
||||
text_to_speech: FeatureToggleDict
|
||||
retriever_resource: FeatureToggleDict
|
||||
annotation_reply: FeatureToggleDict
|
||||
more_like_this: FeatureToggleDict
|
||||
user_input_form: list[dict[str, Any]]
|
||||
sensitive_word_avoidance: dict[str, Any]
|
||||
file_upload: dict[str, Any]
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -9,6 +8,7 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
|
||||
from core.rag.entities import MetadataFilteringCondition
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
@ -111,31 +111,6 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
SupportedComparisonOperator = Literal[
|
||||
# for string or array
|
||||
"contains",
|
||||
"not contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"is not",
|
||||
"empty",
|
||||
"not empty",
|
||||
"in",
|
||||
"not in",
|
||||
# for number
|
||||
"=",
|
||||
"≠",
|
||||
">",
|
||||
"<",
|
||||
"≥",
|
||||
"≤",
|
||||
# for time
|
||||
"before",
|
||||
"after",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
@ -143,25 +118,6 @@ class ModelConfig(BaseModel):
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition detail
|
||||
"""
|
||||
|
||||
name: str
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None | int | float = None
|
||||
|
||||
|
||||
class MetadataFilteringCondition(BaseModel):
|
||||
"""
|
||||
Metadata Filtering Condition.
|
||||
"""
|
||||
|
||||
logical_operator: Literal["and", "or"] | None = "and"
|
||||
conditions: list[Condition] | None = Field(default=None, deprecated=True)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
"""
|
||||
Dataset Retrieve Config Entity.
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Account | EndUser,
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: ConversationSnapshot,
|
||||
message: MessageSnapshot,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
|
||||
) -> Mapping | Generator[Mapping | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
def _error_to_stream_response(cls, e: Exception):
|
||||
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
error_responses = {
|
||||
error_responses: dict[type[Exception], dict[str, Any]] = {
|
||||
ValueError: {"code": "invalid_param", "status": 400},
|
||||
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
|
||||
QuotaExceededError: {
|
||||
@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
|
||||
}
|
||||
|
||||
# Determine the response based on the type of exception
|
||||
data = None
|
||||
data: dict[str, Any] | None = None
|
||||
for k, v in error_responses.items():
|
||||
if isinstance(e, k):
|
||||
data = v
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, cast, overload
|
||||
from typing import Any, Literal, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
|
||||
# Add null check for dataset
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
context: contextvars.Context,
|
||||
pipeline: Pipeline,
|
||||
workflow_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Union[Account, EndUser],
|
||||
account: Account | EndUser,
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
start_node_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
"""
|
||||
Format datasource info list.
|
||||
|
||||
@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
"""
|
||||
@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.system_variables import (
|
||||
build_bootstrap_variables,
|
||||
|
||||
@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class QueueEvent(StrEnum):
|
||||
|
||||
@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
|
||||
entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
_GenerateEntityUnion: TypeAlias = Annotated[
|
||||
type _GenerateEntityUnion = Annotated[
|
||||
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
|
||||
_precomputed_event_type: StreamEvent | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
|
||||
],
|
||||
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse
|
||||
| CompletionAppBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
|
||||
):
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
|
||||
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
@ -511,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
agent_thought: MessageAgentThought | None = (
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
agent_thought: MessageAgentThought | None = session.scalar(
|
||||
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
|
||||
@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
from typing import TYPE_CHECKING, cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
@ -10,7 +10,7 @@ associates with the node span.
|
||||
import logging
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final
|
||||
from typing import cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
@ -18,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
from opentelemetry import context as context_api
|
||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.otel.parser import (
|
||||
|
||||
@ -6,7 +6,7 @@ from sqlalchemy import select, update
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities import RetrievalSourceMetadata
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
||||
@ -345,8 +345,8 @@ class DatasourceManager:
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = (
|
||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
||||
upload_file = session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if not upload_file:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
@ -1,22 +1,3 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: str | None = Field(default=None)
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _(self):
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
return self
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
__all__ = ["I18nObject", "I18nObjectDict"]
|
||||
|
||||
@ -9,7 +9,7 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
|
||||
@ -1 +1,8 @@
|
||||
from core.entities.plugin_credential_type import PluginCredentialType
|
||||
|
||||
DEFAULT_PLUGIN_ID = "langgenius"
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_PLUGIN_ID",
|
||||
"PluginCredentialType",
|
||||
]
|
||||
|
||||
@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
|
||||
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
|
||||
9
api/core/entities/plugin_credential_type.py
Normal file
9
api/core/entities/plugin_credential_type.py
Normal file
@ -0,0 +1,9 @@
|
||||
import enum
|
||||
|
||||
|
||||
class PluginCredentialType(enum.Enum):
|
||||
MODEL = 0 # must be 0 for API contract compatibility
|
||||
TOOL = 1 # must be 1 for API contract compatibility
|
||||
|
||||
def to_number(self):
|
||||
return self.value
|
||||
@ -22,6 +22,7 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import PluginCredentialType
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@ -46,7 +47,6 @@ from models.provider import (
|
||||
TenantPreferredModelProvider,
|
||||
)
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Credential utility functions for checking credential existence and policy compliance.
|
||||
"""
|
||||
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from core.entities import PluginCredentialType
|
||||
|
||||
|
||||
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:
|
||||
|
||||
@ -2,12 +2,13 @@ import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import AnyStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
||||
def import_module_from_source[T: (str, bytes)](
|
||||
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
|
||||
) -> ModuleType:
|
||||
"""
|
||||
Importing a module from the source file directly
|
||||
"""
|
||||
|
||||
@ -2,7 +2,6 @@ import os
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import TypeVar
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
|
||||
return position_map
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def is_filtered(
|
||||
def is_filtered[T](
|
||||
include_set: set[str],
|
||||
exclude_set: set[str],
|
||||
data: T,
|
||||
@ -97,11 +93,11 @@ def is_filtered(
|
||||
return False
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
def sort_by_position_map[T](
|
||||
position_map: dict[str, int],
|
||||
data: list[T],
|
||||
name_func: Callable[[T], str],
|
||||
):
|
||||
) -> list[T]:
|
||||
"""
|
||||
Sort the objects by the position map.
|
||||
If the name of the object is not in the position map, it will be put at the end.
|
||||
@ -116,11 +112,11 @@ def sort_by_position_map(
|
||||
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
|
||||
|
||||
|
||||
def sort_to_dict_by_position_map(
|
||||
def sort_to_dict_by_position_map[T](
|
||||
position_map: dict[str, int],
|
||||
data: list[T],
|
||||
name_func: Callable[[T], str],
|
||||
):
|
||||
) -> OrderedDict[str, T]:
|
||||
"""
|
||||
Sort the objects into a ordered dict by the position map.
|
||||
If the name of the object is not in the position map, it will be put at the end.
|
||||
|
||||
@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
Headers: TypeAlias = dict[str, str]
|
||||
_HEADERS_ADAPTER = TypeAdapter(Headers)
|
||||
type Headers = dict[str, str]
|
||||
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
|
||||
|
||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol, cast
|
||||
from typing import Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
@ -49,6 +49,17 @@ class WorkflowServiceInterface(Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class CodeGenerateResultDict(TypedDict):
|
||||
code: str
|
||||
language: str
|
||||
error: str
|
||||
|
||||
|
||||
class StructuredOutputResultDict(TypedDict):
|
||||
output: str
|
||||
error: str
|
||||
|
||||
|
||||
class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_conversation_name(
|
||||
@ -293,7 +304,7 @@ class LLMGenerator:
|
||||
cls,
|
||||
tenant_id: str,
|
||||
args: RuleCodeGeneratePayload,
|
||||
):
|
||||
) -> CodeGenerateResultDict:
|
||||
if args.code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
@ -362,7 +373,9 @@ class LLMGenerator:
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
|
||||
def generate_structured_output(
|
||||
cls, tenant_id: str, args: RuleStructuredOutputPayload
|
||||
) -> StructuredOutputResultDict:
|
||||
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
@ -454,7 +467,7 @@ class LLMGenerator:
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
app: App | None = session.query(App).where(App.id == flow_id).first()
|
||||
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
|
||||
if not app:
|
||||
raise ValueError("App not found.")
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user