Compare commits


1 Commit

Author: yyh
SHA1: a33791c3de
Message: force node24 for legacy GitHub actions
Date: 2026-03-30 15:39:14 +08:00
2439 changed files with 29964 additions and 38671 deletions

View File

@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
// Component only adds UI behavior.
updateAccessMode({ appId, mode }, {
onSuccess: () => toast.success('...'),
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
})
// Avoid putting invalidation knowledge in the component.
@ -114,7 +114,10 @@ try {
router.push(`/orders/${order.id}`)
}
catch (error) {
toast.error(error instanceof Error ? error.message : 'Unknown error')
Toast.notify({
type: 'error',
message: error instanceof Error ? error.message : 'Unknown error',
})
}
```

View File

@ -6,6 +6,7 @@ runs:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
working-directory: web
node-version-file: .nvmrc
cache: true
run-install: true

View File

@ -9,6 +9,9 @@ on:
permissions:
contents: read
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
concurrency:
group: api-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
@ -35,7 +38,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -84,7 +87,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -156,7 +159,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
@ -203,7 +206,7 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
files: ./coverage.xml
disable_search: true

View File

@ -10,6 +10,9 @@ on:
permissions:
contents: read
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
jobs:
autofix:
if: github.repository == 'langgenius/dify'
@ -39,10 +42,6 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes
@ -56,7 +55,7 @@ jobs:
python-version: "3.11"
- if: github.event_name != 'merge_group'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
- name: Generate Docker Compose
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@ -24,39 +24,27 @@ env:
jobs:
build:
runs-on: ${{ matrix.runs_on }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
if: github.repository == 'langgenius/dify'
strategy:
matrix:
include:
- service_name: "build-api-amd64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
steps:
- name: Prepare
@ -70,6 +58,9 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@ -83,8 +74,7 @@ jobs:
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
@ -103,7 +93,7 @@ jobs:
- name: Upload digest
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1

View File

@ -3,6 +3,9 @@ name: DB Migration Test
on:
workflow_call:
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
concurrency:
group: db-migration-test-${{ github.ref }}
cancel-in-progress: true
@ -19,7 +22,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"
@ -69,7 +72,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -6,12 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}
@ -19,31 +14,26 @@ concurrency:
jobs:
build-docker:
runs-on: ${{ matrix.runs_on }}
runs-on: ubuntu-latest
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@ -51,8 +41,8 @@ jobs:
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@ -16,6 +16,9 @@ permissions:
checks: write
statuses: write
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
concurrency:
group: main-ci-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
@ -65,10 +68,6 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
e2e:
@ -77,10 +76,6 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
- '.github/workflows/web-e2e.yml'

View File

@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
@ -50,17 +50,6 @@ jobs:
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Check if line counts match
id: line_count_check
run: |
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
if [ "$base_lines" -eq "$pr_lines" ]; then
echo "same=true" >> $GITHUB_OUTPUT
else
echo "same=false" >> $GITHUB_OUTPUT
fi
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
@ -74,7 +63,7 @@ jobs:
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: false
python-version: "3.12"
@ -77,10 +77,6 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@ -94,9 +90,9 @@ jobs:
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@ -6,9 +6,6 @@ on:
- main
paths:
- sdks/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -1,10 +1,10 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are bridged by trigger-i18n-sync.yml via repository_dispatch.
on:
repository_dispatch:
types: [i18n-sync]
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
inputs:
files:
@ -30,7 +30,7 @@ permissions:
concurrency:
group: translate-i18n-${{ github.event_name }}-${{ github.ref }}
cancel-in-progress: false
cancel-in-progress: ${{ github.event_name == 'push' }}
jobs:
translate:
@ -67,113 +67,19 @@ jobs:
}
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() {
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
}
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
BASE_SHA="${{ github.event.client_payload.base_sha }}"
HEAD_SHA="${{ github.event.client_payload.head_sha }}"
CHANGED_FILES="${{ github.event.client_payload.changed_files }}"
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
SYNC_MODE="${{ github.event.client_payload.sync_mode || 'incremental' }}"
if [ -n "${{ github.event.client_payload.changes_base64 }}" ]; then
printf '%s' '${{ github.event.client_payload.changes_base64 }}' | base64 -d > /tmp/i18n-changes.json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="embedded"
elif [ -n "$BASE_SHA" ] && [ -n "$CHANGED_FILES" ]; then
export BASE_SHA HEAD_SHA CHANGED_FILES
generate_changes_json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="recomputed"
else
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
CHANGES_AVAILABLE="false"
CHANGES_SOURCE="unavailable"
if [ "${{ github.event_name }}" = "push" ]; then
BASE_SHA="${{ github.event.before }}"
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
fi
HEAD_SHA="${{ github.sha }}"
if [ -n "$BASE_SHA" ]; then
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
else
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
fi
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
SYNC_MODE="incremental"
else
BASE_SHA=""
HEAD_SHA=$(git rev-parse HEAD)
@ -198,17 +104,6 @@ jobs:
else
CHANGED_FILES=""
fi
if [ "$SYNC_MODE" = "incremental" ] && [ -n "$CHANGED_FILES" ]; then
export BASE_SHA HEAD_SHA CHANGED_FILES
generate_changes_json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="local"
else
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
CHANGES_AVAILABLE="false"
CHANGES_SOURCE="unavailable"
fi
fi
FILE_ARGS=""
@ -228,8 +123,6 @@ jobs:
echo "CHANGED_FILES=$CHANGED_FILES"
echo "TARGET_LANGS=$TARGET_LANGS"
echo "SYNC_MODE=$SYNC_MODE"
echo "CHANGES_AVAILABLE=$CHANGES_AVAILABLE"
echo "CHANGES_SOURCE=$CHANGES_SOURCE"
echo "FILE_ARGS=$FILE_ARGS"
echo "LANG_ARGS=$LANG_ARGS"
} >> "$GITHUB_OUTPUT"
@ -248,7 +141,7 @@ jobs:
show_full_output: ${{ github.event_name == 'workflow_dispatch' }}
prompt: |
You are the i18n sync agent for the Dify repository.
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`.
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result.
Use absolute paths at all times:
- Repo root: `${{ github.workspace }}`
@ -263,15 +156,12 @@ jobs:
- Head SHA: `${{ steps.context.outputs.HEAD_SHA }}`
- Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}`
- Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}`
- Structured change set available: `${{ steps.context.outputs.CHANGES_AVAILABLE }}`
- Structured change set source: `${{ steps.context.outputs.CHANGES_SOURCE }}`
- Structured change set file: `/tmp/i18n-changes.json`
Tool rules:
- Use Read for repository files.
- Use Edit for JSON updates.
- Use Bash only for `pnpm`.
- Do not use Bash for `git`, `gh`, or branch management.
- Use Bash only for `git`, `gh`, `pnpm`, and `date`.
- Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
Required execution plan:
1. Resolve target languages.
@ -282,25 +172,27 @@ jobs:
- Only process the resolved target languages, never `en-US`.
- Do not touch unrelated i18n files.
- Do not modify `${{ github.workspace }}/web/i18n/en-US/`.
3. Resolve source changes.
- If `Structured change set available` is `true`, read `/tmp/i18n-changes.json` and use it as the source of truth for file-level and key-level changes.
- For each file entry:
- `added` contains new English keys that need translations.
- `updated` contains stale keys whose English source changed; re-translate using the `after` value.
- `deleted` contains keys that should be removed from locale files.
- `fileDeleted: true` means the English file no longer exists; remove the matching locale file if present.
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
3. Detect English changes per file.
- Read the current English JSON file for each file in scope.
- If sync mode is `incremental` and `Base SHA` is not empty, run:
`git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
- If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
- If the file did not exist at Base SHA, treat all current keys as ADD.
- Compare previous and current English JSON to identify:
- ADD: key only in current
- UPDATE: key exists in both and the English value changed
- DELETE: key only in previous
- Do not rely on a truncated diff file.
4. Run a scoped pre-check before editing:
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations.
- For every target language and scoped file:
- If `fileDeleted` is `true`, remove the locale file if it exists and skip the rest of that file.
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys.
- UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching.
@ -308,119 +200,14 @@ jobs:
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests.
7. Create a PR only when there are changes in `web/i18n/`.
- Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
- Create branch `chore/i18n-sync-<timestamp>`
- Commit message: `chore(i18n): sync translations with en-US`
- Push the branch and open a PR against `main`
- PR title: `chore(i18n): sync translations with en-US`
- PR body: summarize files, languages, sync mode, and verification commands
8. If there are no translation changes after verification, do not create a branch, commit, or PR.
claude_args: |
--max-turns 120
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
- name: Prepare branch metadata
id: pr_meta
if: steps.context.outputs.CHANGED_FILES != ''
shell: bash
run: |
if [ -z "$(git -C "${{ github.workspace }}" status --porcelain -- web/i18n/)" ]; then
echo "has_changes=false" >> "$GITHUB_OUTPUT"
exit 0
fi
SCOPE_HASH=$(printf '%s|%s|%s' "${{ steps.context.outputs.CHANGED_FILES }}" "${{ steps.context.outputs.TARGET_LANGS }}" "${{ steps.context.outputs.SYNC_MODE }}" | sha256sum | cut -c1-8)
HEAD_SHORT=$(printf '%s' "${{ steps.context.outputs.HEAD_SHA }}" | cut -c1-12)
BRANCH_NAME="chore/i18n-sync-${HEAD_SHORT}-${SCOPE_HASH}"
{
echo "has_changes=true"
echo "branch_name=$BRANCH_NAME"
} >> "$GITHUB_OUTPUT"
- name: Commit translation changes
if: steps.pr_meta.outputs.has_changes == 'true'
shell: bash
run: |
git -C "${{ github.workspace }}" checkout -B "${{ steps.pr_meta.outputs.branch_name }}"
git -C "${{ github.workspace }}" add web/i18n/
git -C "${{ github.workspace }}" commit -m "chore(i18n): sync translations with en-US"
- name: Push translation branch
if: steps.pr_meta.outputs.has_changes == 'true'
shell: bash
run: |
if git -C "${{ github.workspace }}" ls-remote --exit-code --heads origin "${{ steps.pr_meta.outputs.branch_name }}" >/dev/null 2>&1; then
git -C "${{ github.workspace }}" push --force-with-lease origin "${{ steps.pr_meta.outputs.branch_name }}"
else
git -C "${{ github.workspace }}" push --set-upstream origin "${{ steps.pr_meta.outputs.branch_name }}"
fi
- name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true'
env:
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
SYNC_MODE: ${{ steps.context.outputs.SYNC_MODE }}
CHANGES_SOURCE: ${{ steps.context.outputs.CHANGES_SOURCE }}
BASE_SHA: ${{ steps.context.outputs.BASE_SHA }}
HEAD_SHA: ${{ steps.context.outputs.HEAD_SHA }}
REPO_NAME: ${{ github.repository }}
shell: bash
run: |
PR_BODY_FILE=/tmp/i18n-pr-body.md
LANG_COUNT=$(printf '%s\n' "$TARGET_LANGS" | wc -w | tr -d ' ')
if [ "$LANG_COUNT" = "0" ]; then
LANG_COUNT="0"
fi
export LANG_COUNT
node <<'NODE' > "$PR_BODY_FILE"
const fs = require('node:fs')
const changesPath = '/tmp/i18n-changes.json'
const changes = fs.existsSync(changesPath)
? JSON.parse(fs.readFileSync(changesPath, 'utf8'))
: { changes: {} }
const filesInScope = (process.env.FILES_IN_SCOPE || '').split(/\s+/).filter(Boolean)
const lines = [
'## Summary',
'',
`- **Files synced**: \`${process.env.FILES_IN_SCOPE || '<none>'}\``,
`- **Languages updated**: ${process.env.TARGET_LANGS || '<none>'} (${process.env.LANG_COUNT} languages)`,
`- **Sync mode**: ${process.env.SYNC_MODE}${process.env.BASE_SHA ? ` (base: \`${process.env.BASE_SHA.slice(0, 10)}\`, head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)` : ` (head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)`}`,
'',
'### Key changes',
]
for (const fileName of filesInScope) {
const fileChange = changes.changes?.[fileName] || { added: {}, updated: {}, deleted: [], fileDeleted: false }
const addedKeys = Object.keys(fileChange.added || {})
const updatedKeys = Object.keys(fileChange.updated || {})
const deletedKeys = fileChange.deleted || []
lines.push(`- \`${fileName}\`: +${addedKeys.length} / ~${updatedKeys.length} / -${deletedKeys.length}${fileChange.fileDeleted ? ' (file deleted in en-US)' : ''}`)
}
lines.push(
'',
'## Verification',
'',
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
'',
'## Notes',
'',
'- This PR was generated from structured en-US key changes produced by `trigger-i18n-sync.yml`.',
`- Structured change source: ${process.env.CHANGES_SOURCE || 'unknown'}.`,
'- Branch name is deterministic for the head SHA and scope, so reruns update the same PR instead of opening duplicates.',
'',
'🤖 Generated with [Claude Code](https://claude.com/claude-code)'
)
process.stdout.write(lines.join('\n'))
NODE
EXISTING_PR_NUMBER=$(gh pr list --repo "$REPO_NAME" --head "$BRANCH_NAME" --state open --json number --jq '.[0].number')
if [ -n "$EXISTING_PR_NUMBER" ] && [ "$EXISTING_PR_NUMBER" != "null" ]; then
gh pr edit "$EXISTING_PR_NUMBER" --repo "$REPO_NAME" --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
else
gh pr create --repo "$REPO_NAME" --head "$BRANCH_NAME" --base main --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
fi
--max-turns 80
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

View File

@ -1,171 +0,0 @@
name: Trigger i18n Sync on Push
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
concurrency:
group: trigger-i18n-sync-${{ github.ref }}
cancel-in-progress: true
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Detect changed files and build structured change set
id: detect
shell: bash
run: |
BASE_SHA="${{ github.event.before }}"
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
fi
HEAD_SHA="${{ github.sha }}"
if [ -n "$BASE_SHA" ]; then
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
else
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
fi
export BASE_SHA HEAD_SHA CHANGED_FILES
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"
else
echo "has_changes=false" >> "$GITHUB_OUTPUT"
fi
echo "base_sha=$BASE_SHA" >> "$GITHUB_OUTPUT"
echo "head_sha=$HEAD_SHA" >> "$GITHUB_OUTPUT"
echo "changed_files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
env:
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
CHANGED_FILES: ${{ steps.detect.outputs.changed_files }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs')
const changesJson = fs.readFileSync('/tmp/i18n-changes.json', 'utf8')
const changesBase64 = Buffer.from(changesJson).toString('base64')
const maxEmbeddedChangesChars = 48000
const changesEmbedded = changesBase64.length <= maxEmbeddedChangesChars
if (!changesEmbedded) {
console.log(`Structured change set too large to embed safely (${changesBase64.length} chars). Downstream workflow will regenerate it from git history.`)
}
await github.rest.repos.createDispatchEvent({
owner: context.repo.owner,
repo: context.repo.repo,
event_type: 'i18n-sync',
client_payload: {
changed_files: process.env.CHANGED_FILES,
changes_base64: changesEmbedded ? changesBase64 : '',
changes_embedded: changesEmbedded,
sync_mode: 'incremental',
base_sha: process.env.BASE_SHA,
head_sha: process.env.HEAD_SHA,
},
})
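
Note: both the deleted `trigger-i18n-sync.yml` above and the inline fallback removed from `translate-i18n.yml` computed the same key-level change set: compare each en-US JSON file at the base SHA against the current file and classify keys as added, updated, or deleted. A minimal Python sketch of that diff, assuming flat JSON objects (the Node script compared nested values by serializing them, which the `sort_keys` comparison below approximates):

```python
import json

def _same(a, b) -> bool:
    # Serialize-and-compare, mirroring the removed JSON.stringify check.
    return json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)

def diff_locale(before: dict, after: dict) -> dict:
    """Key-level diff between two en-US snapshots (sketch of the removed Node logic)."""
    added = {k: v for k, v in after.items() if k not in before}
    updated = {
        k: {"before": before[k], "after": v}
        for k, v in after.items()
        if k in before and not _same(before[k], v)
    }
    deleted = [k for k in before if k not in after]
    return {"added": added, "updated": updated, "deleted": deleted}

print(diff_locale({"a": "x", "b": "y"}, {"b": "z", "c": "w"}))
# {'added': {'c': 'w'}, 'updated': {'b': {'before': 'y', 'after': 'z'}}, 'deleted': ['a']}
```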

View File

@ -1,95 +0,0 @@
name: Run Full VDB Tests
on:
schedule:
- cron: '0 3 * * 1'
workflow_dispatch:
permissions:
contents: read
concurrency:
group: vdb-tests-full-${{ github.ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.12"
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
cache-dependency-glob: api/uv.lock
- name: Check UV lockfile
run: uv lock --project api --check
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
# compose-file: docker/tidb/docker-compose.yaml
# services: |
# tidb
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -1,10 +1,10 @@
name: Run VDB Smoke Tests
name: Run VDB Tests
on:
workflow_call:
permissions:
contents: read
env:
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}
@ -12,7 +12,7 @@ concurrency:
jobs:
test:
name: VDB Smoke Tests
name: VDB Tests
runs-on: ubuntu-latest
strategy:
matrix:
@ -33,7 +33,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@ -61,18 +61,23 @@ jobs:
# tidb
# tiflash
- name: Set up Vector Stores for Smoke Coverage
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
db_postgres
redis
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
@ -84,9 +89,4 @@ jobs:
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@ -27,8 +27,12 @@ jobs:
- name: Setup web dependencies
uses: ./.github/actions/setup-web
- name: Install E2E package dependencies
working-directory: ./e2e
run: vp install --frozen-lockfile
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: "3.12"

View File

@ -83,9 +83,40 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
with:
directory: web/coverage
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
web-build:
name: Web Build
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run build

.gitignore (vendored): 4 changes
View File

@ -212,8 +212,6 @@ api/.vscode
# pnpm
/.pnpm-store
node_modules
.vite-hooks/_
# plugin migrate
plugins.jsonl
@ -241,4 +239,4 @@ scripts/stress-test/reports/
*.local.md
# Code Agent Folder
.qoder/*
.qoder/*

.npmrc: 1 change
View File

@ -1 +0,0 @@
save-exact=true

View File

@ -24,8 +24,8 @@ prepare-docker:
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
@pnpm install
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
@ -93,7 +93,7 @@ test:
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
docker build -t $(WEB_IMAGE):$(VERSION) ./web
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
build-api:

View File

@ -115,6 +115,12 @@ ignore = [
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]

View File

@ -40,8 +40,6 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/start-web
```
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
1. Set up your application by visiting `http://localhost:3000`.
1. Start the worker service (async and scheduler tasks, runs from `api`).

View File

@ -7,16 +7,15 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
_IMAGE_EXTENSION_BASE: frozenset[str] = frozenset(("jpg", "jpeg", "png", "webp", "gif", "svg"))
_VIDEO_EXTENSION_BASE: frozenset[str] = frozenset(("mp4", "mov", "mpeg", "webm"))
_AUDIO_EXTENSION_BASE: frozenset[str] = frozenset(("mp3", "m4a", "wav", "amr", "mpga"))
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
IMAGE_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_IMAGE_EXTENSION_BASE))
VIDEO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_VIDEO_EXTENSION_BASE))
AUDIO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_AUDIO_EXTENSION_BASE))
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
(
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = {
"txt",
"markdown",
"md",
@ -36,10 +35,11 @@ _UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
"pptx",
"xml",
"epub",
)
)
_DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
(
}
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.add("ppt")
else:
_doc_extensions = {
"txt",
"markdown",
"md",
@ -53,17 +53,8 @@ _DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
"csv",
"vtt",
"properties",
)
)
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = set(_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE)
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.add("ppt")
else:
_doc_extensions = set(_DEFAULT_DOCUMENT_EXTENSION_BASE)
DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_doc_extensions))
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"
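
Note: both sides of this hunk depend on `convert_to_lower_and_upper_set`; the revert mainly changes whether the result is additionally wrapped in `frozenset`. A sketch of what the helper plausibly does, assuming it simply emits both casings (the real implementation lives elsewhere in the api package):

```python
# Hypothetical re-implementation for illustration; the real helper is
# defined elsewhere in the api package.
def convert_to_lower_and_upper_set(extensions: set[str] | frozenset[str]) -> set[str]:
    """Expand {"jpg"} into {"jpg", "JPG"} so extension checks ignore case."""
    return {ext.lower() for ext in extensions} | {ext.upper() for ext in extensions}

IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
assert {"png", "PNG"} <= IMAGE_EXTENSIONS
```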

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from pydantic import BaseModel
@ -188,6 +188,8 @@ class ExecutionContextBuilder:
_capturer: Callable[[], IExecutionContext] | None = None
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing."""

View File

@ -1,4 +1,7 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
@ -8,7 +11,7 @@ class HiddenValue:
_default = HiddenValue()
class RecyclableContextVar[T]:
class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
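
Note: this hunk trades the PEP 695 class syntax (`class RecyclableContextVar[T]:`, Python 3.12+) for the `Generic[T]` spelling that also runs on 3.11 and earlier, presumably to support older interpreters. The two forms are equivalent at runtime; a side-by-side sketch:

```python
from typing import Generic, TypeVar

T = TypeVar("T")

# Spelling adopted in this hunk (works on 3.11 and earlier):
class Box(Generic[T]):
    def __init__(self, value: T) -> None:
        self.value = value

# PEP 695 spelling removed here (3.12+ only):
#   class Box[T]:
#       def __init__(self, value: T) -> None:
#           self.value = value

box = Box(42)  # inferred as Box[int]
print(box.value + 1)
```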

View File

@ -1,14 +1,14 @@
from __future__ import annotations
from typing import Any
from typing import Any, TypeAlias
from graphon.file import helpers as file_helpers
from pydantic import BaseModel, ConfigDict, computed_field
from models.model import IconType
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
type JSONObject = dict[str, Any]
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
class SystemParameters(BaseModel):
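
Note: the same downgrade applies to aliases: the PEP 695 `type X = ...` statement (3.12+) becomes a PEP 613 `X: TypeAlias = ...` annotation, which 3.10+ accepts. A minimal sketch:

```python
import json
from typing import Any, TypeAlias

# PEP 613 spelling adopted in this hunk (accepted on 3.10+):
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]

# PEP 695 spelling removed here (3.12+ only):
#   type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]

def dump(value: JSONValue) -> str:
    return json.dumps(value)

print(dump({"ok": True}))
```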

View File

@ -4,8 +4,8 @@ from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES: frozenset[str] = frozenset(("text/html", "application/xhtml+xml"))
HTML_EXTENSIONS: frozenset[str] = frozenset(("html", "htm"))
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:

View File

@ -2,6 +2,7 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
@ -19,6 +20,9 @@ from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -68,9 +72,9 @@ console_ns.schema_model(
)
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def admin_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
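
Note: the decorator loses its inline PEP 695 type parameters (`def admin_required[**P, R](...)`) in favor of the module-level `P = ParamSpec("P")` and `R = TypeVar("R")` added above, which type-check the same way on 3.10/3.11. A minimal sketch of the signature-preserving decorator pattern (return annotations kept here even though the hunk drops them):

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def admin_required(view: Callable[P, R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # The admin-key check would run here; ParamSpec keeps the wrapped
        # signature visible to callers and type checkers alike.
        return view(*args, **kwargs)
    return decorated

@admin_required
def get_banner(banner_id: int, *, lang: str = "en") -> str:
    return f"banner {banner_id} ({lang})"

print(get_banner(1, lang="ja"))
```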

View File

@ -2,7 +2,7 @@ import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
@ -34,7 +34,7 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
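
Note: this is the template for the many `sessionmaker(...).begin()` to `Session(...)` swaps in the hunks below. The semantics differ: `sessionmaker(db.engine).begin()` opens a transaction that commits automatically on clean exit, while a plain `Session` context manager only closes the session, which is why later hunks add explicit `session.commit()` calls. A runnable sketch against an in-memory SQLite engine rather than Dify's models:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite+pysqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE items (name TEXT)"))

# Removed pattern: begin() commits automatically on clean exit.
with sessionmaker(engine).begin() as session:
    session.execute(text("INSERT INTO items VALUES ('auto-committed')"))

# Adopted pattern: the session closes on exit, but nothing commits
# unless the code says so.
with Session(engine) as session:
    session.execute(text("INSERT INTO items VALUES ('explicit')"))
    session.commit()  # omit this and the row is rolled back on close

with Session(engine) as session:
    print(session.execute(text("SELECT name FROM items")).scalars().all())
    # ['auto-committed', 'explicit']
```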

View File

@ -1,7 +1,7 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal
from typing import Any, Literal, TypeAlias
from flask import request
from flask_restx import Resource
@ -9,7 +9,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
return value
type JSONValue = Any
JSONValue: TypeAlias = Any
class ResponseModel(BaseModel):
@ -642,7 +642,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@ -655,6 +655,7 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -87,6 +87,7 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@ -111,11 +112,12 @@ class AppImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -132,7 +134,7 @@ class AppImportCheckDependenciesApi(Resource):
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource):
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {

View File

@ -9,8 +9,8 @@ from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import sessionmaker
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
@ -268,18 +268,22 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
elif "text/plain" in content_type:
try:
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
payload_data = json.loads(request.data.decode("utf-8"))
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()
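
Note: the `text/plain` branch now parses the body with `json.loads`, rejects non-object payloads, and funnels both content types through one `model_validate` call, instead of handing raw bytes to `model_validate_json`. A minimal sketch of the two approaches, with a hypothetical stand-in for `SyncDraftWorkflowPayload`:

```python
import json

from pydantic import BaseModel

class Payload(BaseModel):  # hypothetical stand-in for SyncDraftWorkflowPayload
    graph: dict = {}

raw = b'{"graph": {}}'

# One-step (removed): pydantic parses and validates the bytes itself.
payload = Payload.model_validate_json(raw)

# Two-step (adopted): parse first, reject non-objects, then validate the
# dict on the same code path as the application/json branch.
try:
    data = json.loads(raw.decode("utf-8"))
except json.JSONDecodeError:
    raise SystemExit("Invalid JSON data")
if not isinstance(data, dict):
    raise SystemExit("Invalid JSON data")
payload = Payload.model_validate(data)
print(payload)
```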
@ -836,7 +840,7 @@ class PublishedWorkflowApi(Resource):
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
@ -854,6 +858,8 @@ class PublishedWorkflowApi(Resource):
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
@ -976,7 +982,7 @@ class PublishedAllWorkflowApi(Resource):
raise Forbidden()
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
@ -1066,7 +1072,7 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService()
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
@ -1078,6 +1084,9 @@ class WorkflowByIdApi(Resource):
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@ -1092,11 +1101,13 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService()
# Create a session and manage the transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:

View File

@ -5,7 +5,7 @@ from flask import request
from flask_restx import Resource, marshal_with
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import Any, NoReturn, ParamSpec, TypeVar
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.error import (
@ -192,8 +192,11 @@ workflow_draft_variable_list_model = console_ns.model(
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
)
P = ParamSpec("P")
R = TypeVar("R")
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite(f: Callable[P, R]):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -210,7 +213,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(*args: P.args, **kwargs: P.kwargs):
return f(*args, **kwargs)
return wrapper
@ -241,7 +244,7 @@ class WorkflowVariableCollectionApi(Resource):
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@ -267,7 +270,7 @@ class WorkflowVariableCollectionApi(Resource):
return Response("", 204)
def validate_node_id(node_id: str) -> None:
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
@ -282,6 +285,7 @@ def validate_node_id(node_id: str) -> None:
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
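
Note: `NoReturn | None` documents that the function either raises (for invalid node ids) or falls through and returns `None`. Because `NoReturn` is the bottom type, checkers effectively collapse the union to `None`, so the annotation is documentation more than enforcement. A minimal sketch:

```python
from typing import NoReturn

def validate_positive(value: int) -> NoReturn | None:
    if value <= 0:
        # This branch never returns normally.
        raise ValueError(f"invalid value: {value}")
    return None  # explicit return, mirroring the hunk above

validate_positive(3)
```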
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
@ -294,7 +298,7 @@ class NodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@ -461,7 +465,7 @@ class VariableResetApi(Resource):
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource):
assert current_user.current_tenant_id is not None
trigger_id = args.trigger_id
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
@ -153,6 +153,9 @@ class AppTriggerEnableApi(Resource):
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import overload
from typing import ParamSpec, TypeVar, Union
from sqlalchemy import select
@ -9,6 +9,11 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
P = ParamSpec("P")
R = TypeVar("R")
P1 = ParamSpec("P1")
R1 = TypeVar("R1")
def _load_app_model(app_id: str) -> App | None:
_, current_tenant_id = current_account_with_tenant()
@ -23,30 +28,10 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
return app_model
@overload
def get_app_model[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R]: ...
@overload
def get_app_model[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
@ -84,30 +69,10 @@ def get_app_model[**P, R](
return decorator(view)
@overload
def get_app_model_with_trial[**P, R](
view: Callable[P, R],
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R]: ...
@overload
def get_app_model_with_trial[**P, R](
view: None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def get_app_model_with_trial[**P, R](
view: Callable[P, R] | None = None,
*,
mode: AppMode | list[AppMode] | None = None,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")

View File

@ -1,9 +1,8 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from typing import Concatenate, ParamSpec, TypeVar
from flask import jsonify, request
from flask.typing import ResponseReturnValue
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel
@ -17,6 +16,10 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
from .. import console_ns
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class OAuthClientPayload(BaseModel):
client_id: str
@ -36,11 +39,9 @@ class OAuthTokenRequest(BaseModel):
refresh_token: str | None = None
def oauth_server_client_id_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
) -> Callable[Concatenate[T, P], R]:
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
json_data = request.get_json()
if json_data is None:
raise BadRequest("client_id is required")
@ -57,13 +58,9 @@ def oauth_server_client_id_required[T, **P, R](
return decorated
def oauth_server_access_token_required[T, **P, R](
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
@wraps(view)
def decorated(
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
) -> R | ResponseReturnValue:
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
if not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
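`oauth_server_client_id_required` relies on `Concatenate` so the wrapper can resolve an object and pass it to the view as an extra leading argument, while the outer signature seen by the type checker omits it. A reduced sketch of the same typing scheme (the names below are hypothetical):

from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")

class ProviderApp:  # stand-in for OAuthProviderApp
    client_id = "demo"

def with_provider_app(view: Callable[Concatenate[T, ProviderApp, P], R]) -> Callable[Concatenate[T, P], R]:
    @wraps(view)
    def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
        app = ProviderApp()  # in the real code this is resolved from the request payload
        return view(self, app, *args, **kwargs)
    return decorated

class Api:
    @with_provider_app
    def post(self, app: ProviderApp, grant_type: str) -> str:
        return f"{app.client_id}:{grant_type}"

print(Api().post("authorization_code"))  # demo:authorization_code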

View File

@ -36,7 +36,7 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
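Several hunks here and below validate `request.args.to_dict(flat=True)` with a Pydantic model; the added `# type: ignore` comments only silence a stub mismatch on `to_dict`, the behaviour is unchanged. A minimal sketch of the pattern outside Flask, with a hypothetical query model mirroring `SubscriptionQuery`:

from pydantic import BaseModel

class SubscriptionQuery(BaseModel):  # hypothetical mirror of the real query model
    plan: str
    interval: str = "month"

# request.args.to_dict(flat=True) yields a plain {str: str} mapping like this:
raw_args = {"plan": "team", "interval": "year"}
args = SubscriptionQuery.model_validate(raw_args)
print(args.plan, args.interval)  # team year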

View File

@ -31,7 +31,7 @@ class ComplianceApi(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")

View File

@ -6,7 +6,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import get_or_create_model, register_schema_model
@ -158,11 +158,10 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
_, current_tenant_id = current_account_with_tenant()
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None:
raise NotFound("Data source binding not found.")
@ -212,7 +211,7 @@ class DataSourceNotionListApi(Resource):
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# import Notion pages into the existing dataset
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)

View File

@ -173,11 +173,8 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id, current_tenant_id
)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if external_knowledge_api is None:
raise NotFound("API template not found.")

View File

@ -120,8 +120,7 @@ class DatasourceOAuthCallback(Resource):
if context is None:
raise Forbidden("Invalid context_id")
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
datasource_provider_service = DatasourceProviderService()
@ -142,7 +141,7 @@ class DatasourceOAuthCallback(Resource):
system_credentials=oauth_client_params,
request=request,
)
credential_id: str | None = context.get("credential_id")
credential_id = context.get("credential_id")
if credential_id:
datasource_provider_service.reauthorize_datasource_oauth_provider(
tenant_id=tenant_id,
@ -151,7 +150,7 @@ class DatasourceOAuthCallback(Resource):
name=oauth_response.metadata.get("name") or None,
expire_at=oauth_response.expires_at,
credentials=dict(oauth_response.credentials),
credential_id=credential_id,
credential_id=context.get("credential_id"),
)
else:
datasource_provider_service.add_datasource_oauth_provider(

View File

@ -3,7 +3,7 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@ -85,7 +85,7 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,

View File

@ -1,12 +1,11 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
@ -56,7 +55,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -71,7 +70,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(*args, **kwargs):
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
@ -97,7 +96,7 @@ class RagPipelineVariableCollectionApi(Resource):
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@ -144,7 +143,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@ -290,7 +289,7 @@ class RagPipelineVariableResetApi(Resource):
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)

View File

@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource):
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
@ -80,6 +80,7 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
session.commit()
# Return appropriate status code based on result
status = result.status
@ -101,11 +102,12 @@ class RagPipelineImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -122,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@ -140,7 +142,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@ -5,8 +5,8 @@ from typing import Any, Literal, cast
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import sessionmaker
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
@ -186,14 +186,29 @@ class DraftRagPipelineApi(Resource):
if "application/json" in content_type:
payload_dict = console_ns.payload or {}
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
elif "text/plain" in content_type:
try:
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
rag_pipeline_service = RagPipelineService()
try:
@ -593,15 +608,19 @@ class PublishedRagPipelineApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.publish_workflow(
session=db.session, # type: ignore[reportArgumentType,arg-type]
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
workflow = rag_pipeline_service.publish_workflow(
session=session,
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
session.add(pipeline)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
@ -676,7 +695,7 @@ class PublishedAllRagPipelineApi(Resource):
raise Forbidden()
rag_pipeline_service = RagPipelineService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
session=session,
pipeline=pipeline,
@ -748,7 +767,7 @@ class RagPipelineByIdApi(Resource):
rag_pipeline_service = RagPipelineService()
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow = rag_pipeline_service.update_workflow(
session=session,
workflow_id=workflow_id,
@ -760,6 +779,9 @@ class RagPipelineByIdApi(Resource):
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@ -776,13 +798,14 @@ class RagPipelineByIdApi(Resource):
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=pipeline.tenant_id,
)
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:
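The `DraftRagPipelineApi` hunk above adds a `text/plain` branch that first tries strict Pydantic JSON parsing and, on failure, falls back to `json.loads` plus shape checks before re-validating. A reduced sketch of that control flow, assuming a payload model with `graph` and `features` fields:

import json
from pydantic import BaseModel, ValidationError

class SyncPayload(BaseModel):  # hypothetical stand-in for DraftWorkflowSyncPayload
    graph: dict
    features: dict

def parse_plain_text(body: bytes) -> SyncPayload:
    try:
        # strict path: let Pydantic parse the raw JSON bytes directly
        return SyncPayload.model_validate_json(body)
    except (ValueError, ValidationError):
        # fallback path mirroring the hunk: manual json.loads plus shape checks
        data = json.loads(body.decode("utf-8"))  # json.JSONDecodeError maps to HTTP 400 upstream
        if "graph" not in data or "features" not in data:
            raise ValueError("graph or features not found in data")
        if not isinstance(data.get("graph"), dict):
            raise ValueError("graph is not a dict")
        return SyncPayload.model_validate({"graph": data["graph"], "features": data["features"]})

print(parse_plain_text(b'{"graph": {}, "features": {}}'))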

View File

@ -1,5 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy import select
@ -8,10 +9,13 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.dataset import Pipeline
P = ParamSpec("P")
R = TypeVar("R")
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def get_rag_pipeline(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")

View File

@ -2,7 +2,7 @@ from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
@ -74,7 +74,7 @@ class ConversationListApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
@ -15,8 +15,12 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
@ -45,7 +49,7 @@ def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P],
return decorator
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
@wraps(view)
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
@ -69,7 +73,7 @@ def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp,
return decorator
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
@ -102,7 +106,7 @@ def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = N
return decorator
def trial_feature_enable[**P, R](view: Callable[P, R]):
def trial_feature_enable(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@ -113,7 +117,7 @@ def trial_feature_enable[**P, R](view: Callable[P, R]):
return decorated
def explore_banner_enabled[**P, R](view: Callable[P, R]):
def explore_banner_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()

View File

@ -1,26 +1,30 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models.account import TenantPluginPermission
P = ParamSpec("P")
R = TypeVar("R")
def plugin_permission_required(
install_required: bool = False,
debug_required: bool = False,
):
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
current_user, current_tenant_id = current_account_with_tenant()
user = current_user
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.where(

View File

@ -8,7 +8,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
@ -519,7 +519,7 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@ -562,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()

View File

@ -99,7 +99,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
@ -118,7 +118,7 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return the currently used credential
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()

View File

@ -287,10 +287,12 @@ class ModelProviderModelCredentialApi(Resource):
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=args.model_type,
model_type=normalized_model_type,
model=args.model,
)

View File

@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@ -832,8 +832,7 @@ class ToolOAuthCallback(Resource):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
@ -1019,7 +1018,7 @@ class ToolProviderMCPApi(Resource):
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=payload.provider_id
@ -1034,7 +1033,7 @@ class ToolProviderMCPApi(Resource):
)
# Step 3: Perform database update in a transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
@ -1061,7 +1060,7 @@ class ToolProviderMCPApi(Resource):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
@ -1079,7 +1078,7 @@ class ToolMCPAuthApi(Resource):
provider_id = payload.provider_id
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
@ -1100,7 +1099,7 @@ class ToolMCPAuthApi(Resource):
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
@ -1118,17 +1117,17 @@ class ToolMCPAuthApi(Resource):
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@ -1141,7 +1140,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@ -1155,7 +1154,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
@ -1170,7 +1169,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
@ -1188,7 +1187,7 @@ class ToolMCPCallbackApi(Resource):
authorization_code = query.code
# Create service instance for handle_callback
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
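In the tool-provider hunks above, `sessionmaker(db.engine).begin()` becomes `with Session(db.engine) as session, session.begin():`, which is the like-for-like rewrite: the inner `session.begin()` block still commits on success and rolls back on error, so those call sites need no added `session.commit()`. A short sketch of the combined form, assuming a hypothetical engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # hypothetical engine for illustration

# Equivalent to `with sessionmaker(engine).begin() as session:`:
# the inner begin() commits on clean exit and rolls back on exception.
with Session(engine) as session, session.begin():
    session.execute(text("CREATE TABLE t (x INTEGER)"))
    session.execute(text("INSERT INTO t VALUES (1)"))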

View File

@ -5,7 +5,7 @@ from flask import make_response, redirect, request
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource):
assert user.current_tenant_id is not None
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
@ -388,6 +388,7 @@ class TriggerSubscriptionDeleteApi(Resource):
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
@ -498,9 +499,9 @@ class TriggerOAuthCallbackApi(Resource):
provider_id = TriggerProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
user_id: str = context["user_id"]
tenant_id: str = context["tenant_id"]
subscription_builder_id: str = context["subscription_builder_id"]
user_id = context.get("user_id")
tenant_id = context.get("tenant_id")
subscription_builder_id = context.get("subscription_builder_id")
# Get OAuth client configuration
oauth_client_params = TriggerProviderService.get_oauth_client(

View File

@ -155,7 +155,7 @@ class WorkspaceListApi(Resource):
@setup_required
@admin_required
def get(self):
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = WorkspaceListQuery.model_validate(payload)
stmt = select(Tenant).order_by(Tenant.created_at.desc())

View File

@ -4,6 +4,7 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, request
from sqlalchemy import select
@ -24,6 +25,9 @@ from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
P = ParamSpec("P")
R = TypeVar("R")
# Field names for decryption
FIELD_NAME_PASSWORD = "password"
FIELD_NAME_CODE = "code"
@ -33,7 +37,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check account initialization
@ -46,7 +50,7 @@ def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P,
return decorated
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def only_edition_cloud(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "CLOUD":
@ -57,7 +61,7 @@ def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def only_edition_enterprise(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.ENTERPRISE_ENABLED:
@ -68,7 +72,7 @@ def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def only_edition_self_hosted(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if dify_config.EDITION != "SELF_HOSTED":
@ -79,7 +83,7 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_edition_billing_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@ -91,7 +95,7 @@ def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R
return decorated
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_resource_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@ -133,9 +137,7 @@ def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Cal
return interceptor
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_knowledge_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@ -158,7 +160,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
return interceptor
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@ -194,7 +196,7 @@ def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[C
return interceptor
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_utm_record(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
@ -213,7 +215,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
# check setup
@ -227,7 +229,7 @@ def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_license_required(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
settings = FeatureService.get_system_features()
@ -239,7 +241,7 @@ def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def email_password_login_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@ -252,7 +254,7 @@ def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]
return decorated
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def email_register_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@ -265,7 +267,7 @@ def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enable_change_email(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
features = FeatureService.get_system_features()
@ -278,7 +280,7 @@ def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
from libs.workspace_permission import check_workspace_owner_transfer_permission
@ -291,7 +293,7 @@ def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
@ -303,7 +305,7 @@ def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable
return decorated
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
def edit_permission_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@ -321,7 +323,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
return decorated_function
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
def is_admin_or_owner_required(f: Callable[P, R]):
@wraps(f)
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
@ -337,7 +339,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
return decorated_function
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def annotation_import_rate_limit(view: Callable[P, R]):
"""
Rate limiting decorator for annotation import operations.
@ -386,7 +388,7 @@ def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]
return decorated
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def annotation_import_concurrency_limit(view: Callable[P, R]):
"""
Concurrency control decorator for annotation import operations.
@ -453,7 +455,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
payload[field_name] = decoded_value
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def decrypt_password_field(view: Callable[P, R]):
"""
Decorator to decrypt password field in request payload.
@ -475,7 +477,7 @@ def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def decrypt_code_field(view: Callable[P, R]):
"""
Decorator to decrypt verification code field in request payload.

View File

@ -1,17 +1,21 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
@ -29,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
user_model = None
if is_anonymous:
@ -52,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
session_id=user_id,
)
session.add(user_model)
session.flush()
session.commit()
session.refresh(user_model)
except Exception:
@ -61,9 +65,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
return user_model
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def get_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
user_id = payload.user_id
@ -93,14 +97,10 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
return decorated_view
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
except Exception:
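In `get_user`, `session.flush()` becomes `session.commit()` followed by `session.refresh(user_model)`. Flush only sends the pending INSERT inside the still-open transaction; without the old auto-committing `.begin()` wrapper, an explicit commit is needed for the new `EndUser` row to persist, and the refresh reloads any server-generated values. A compact illustration, assuming a hypothetical model:

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class EndUserDemo(Base):  # hypothetical stand-in for the EndUser model
    __tablename__ = "end_users_demo"
    id: Mapped[int] = mapped_column(primary_key=True)
    session_id: Mapped[str]

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    user = EndUserDemo(session_id="anonymous")
    session.add(user)
    session.flush()        # emits the INSERT and populates user.id,
                           # but the row is not durable until commit
    session.commit()       # makes the row visible outside the transaction
    session.refresh(user)  # reloads attributes (e.g. server-side defaults)
    print(user.id, user.session_id)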

View File

@ -3,7 +3,10 @@ from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -11,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -27,9 +30,9 @@ def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -43,9 +46,9 @@ def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def enterprise_inner_api_user_auth(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
return view(*args, **kwargs)
@ -79,9 +82,9 @@ def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P,
return decorated
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)

View File

@ -4,7 +4,7 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.mcp import mcp_ns
@ -67,7 +67,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
@ -174,7 +174,6 @@ class MCPAppApi(Resource):
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
@ -189,7 +188,7 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session, session.begin():
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
@ -229,7 +228,9 @@ class MCPAppApi(Resource):
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
# Commit the session before creating the end user to avoid transaction conflicts
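# (The comment above describes the added pre-commit; the sketch after the
# flush/commit hunk earlier shows why a bare Session needs it.)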
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -116,7 +116,7 @@ class ConversationApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,

View File

@ -29,31 +29,6 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
class SegmentCreatePayload(BaseModel):
@ -157,7 +132,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@ -221,7 +196,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": _marshal_segments_with_summary(segments, dataset_id),
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -321,7 +296,7 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@ -351,7 +326,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.route(
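The removed `_marshal_segments_with_summary` helper has the classic shape for avoiding an N+1 lookup: collect the ids, fetch all summaries in one batch call, then join in memory. A generic sketch of that shape (the service and field names below are stand-ins, not the real `SummaryIndexService` API):

from dataclasses import dataclass

@dataclass
class Segment:
    id: str
    content: str

def get_summaries(segment_ids: list[str]) -> dict[str, str]:
    """Stand-in for a batch summary query: one call for all ids."""
    return {sid: f"summary of {sid}" for sid in segment_ids}

def marshal_with_summary(segments: list[Segment]) -> list[dict]:
    # one batch fetch instead of one query per segment
    summaries = get_summaries([s.id for s in segments]) if segments else {}
    return [
        {"id": s.id, "content": s.content, "summary": summaries.get(s.id)}
        for s in segments
    ]

print(marshal_with_summary([Segment("a", "x"), Segment("b", "y")]))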

View File

@ -1,10 +1,9 @@
import inspect
import logging
import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import cast, overload
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
from flask import current_app, request
from flask_login import user_logged_in
@ -24,6 +23,10 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
logger = logging.getLogger(__name__)
@ -43,16 +46,16 @@ class FetchUserArg(BaseModel):
@overload
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token[**P, R](
def validate_app_token(
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token[**P, R](
def validate_app_token(
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@ -133,10 +136,7 @@ def validate_app_token[**P, R](
return decorator(view)
def cloud_edition_billing_resource_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
def interceptor(view: Callable[P, R]):
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token(api_token_type)
@ -166,10 +166,7 @@ def cloud_edition_billing_resource_check[**P, R](
return interceptor
def cloud_edition_billing_knowledge_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@ -191,10 +188,7 @@ def cloud_edition_billing_knowledge_limit_check[**P, R](
return interceptor
def cloud_edition_billing_rate_limit_check[**P, R](
resource: str,
api_token_type: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
def interceptor(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
@ -231,73 +225,99 @@ def cloud_edition_billing_rate_limit_check[**P, R](
return interceptor
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
positional_parameters = [
parameter
for parameter in inspect.signature(view).parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@wraps(view)
def decorated(*args: object, **kwargs: object) -> R:
api_token = validate_and_get_api_token("dataset")
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
dataset_id = kwargs.get("dataset_id")
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
if not dataset_id and args:
potential_id = args[0]
try:
str_id = str(potential_id)
if len(str_id) == 36 and str_id.count("-") == 4:
dataset_id = str_id
except Exception:
logger.exception("Failed to parse dataset_id from positional args")
if dataset_id:
dataset_id = str(dataset_id)
dataset = db.session.scalar(
select(Dataset)
.where(
Dataset.id == dataset_id,
Dataset.tenant_id == api_token.tenant_id,
def validate_dataset_token(
    view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
    def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
        @wraps(view_func)
        def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
            api_token = validate_and_get_api_token("dataset")

            # get url path dataset_id from positional args or kwargs
            # Flask passes URL path parameters as positional arguments
            dataset_id = None
            # First try to get from kwargs (explicit parameter)
            dataset_id = kwargs.get("dataset_id")
            # If not in kwargs, try to extract from positional args
            if not dataset_id and args:
                # For class methods: args[0] is self, args[1] is dataset_id (if exists)
                # Check if first arg is likely a class instance (has __dict__ or __class__)
                if len(args) > 1 and hasattr(args[0], "__dict__"):
                    # This is a class method, dataset_id should be in args[1]
                    potential_id = args[1]
                    # Validate it's a string-like UUID, not another object
                    try:
                        # Try to convert to string and check if it's a valid UUID format
                        str_id = str(potential_id)
                        # Basic check: UUIDs are 36 chars with hyphens
                        if len(str_id) == 36 and str_id.count("-") == 4:
                            dataset_id = str_id
                    except Exception:
                        logger.exception("Failed to parse dataset_id from class method args")
                elif len(args) > 0:
                    # Not a class method, check if args[0] looks like a UUID
                    potential_id = args[0]
                    try:
                        str_id = str(potential_id)
                        if len(str_id) == 36 and str_id.count("-") == 4:
                            dataset_id = str_id
                    except Exception:
                        logger.exception("Failed to parse dataset_id from positional args")

            # Validate dataset if dataset_id is provided
            if dataset_id:
                dataset_id = str(dataset_id)
                dataset = db.session.scalar(
                    select(Dataset)
                    .where(
                        Dataset.id == dataset_id,
                        Dataset.tenant_id == api_token.tenant_id,
                    )
                    .limit(1)
                )
                if not dataset:
                    raise NotFound("Dataset not found.")
                if not dataset.enable_api:
                    raise Forbidden("Dataset api access is not enabled.")

            tenant_account_join = db.session.execute(
                select(Tenant, TenantAccountJoin)
                .where(Tenant.id == api_token.tenant_id)
                .where(TenantAccountJoin.tenant_id == Tenant.id)
                .where(TenantAccountJoin.role.in_(["owner"]))
                .where(Tenant.status == TenantStatus.NORMAL)
            ).one_or_none()  # TODO: only owner information is required, so only one is returned.
            if tenant_account_join:
                tenant, ta = tenant_account_join
                account = db.session.get(Account, ta.account_id)
                # Login admin
                if account:
                    account.current_tenant = tenant
                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
                    user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
                else:
                    raise Unauthorized("Tenant owner account does not exist.")
            else:
                raise Unauthorized("Tenant does not exist.")

            return view_func(api_token.tenant_id, *args, **kwargs)  # type: ignore[arg-type]

        return decorated

    if view:
        return decorator(view)
    # if view is None, it means that the decorator is used without parentheses
    # use the decorator as a function for method_decorators
    return decorator
def validate_and_get_api_token(scope: str | None = None):
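
As a quick illustration of the heuristic the decorator above relies on, here is a standalone sketch (the helper name looks_like_uuid is ours, not the repo's) of the 36-characters-with-4-hyphens shape check applied to a positional argument:

def looks_like_uuid(value: object) -> bool:
    # Same basic shape check as above: canonical UUIDs are 36 chars with 4 hyphens
    s = str(value)
    return len(s) == 36 and s.count("-") == 4

assert looks_like_uuid("123e4567-e89b-12d3-a456-426614174000")
assert not looks_like_uuid("not-a-uuid")
assert not looks_like_uuid(object())  # arbitrary objects stringify to something else

Note this is a shape check only: it accepts non-hex strings of the right length, which is presumably acceptable because the database lookup that follows is the real validation.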

View File

@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
from controllers.trigger import bp
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
from services.trigger.webhook_service import WebhookService
logger = logging.getLogger(__name__)
@ -23,7 +23,6 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
webhook_id, is_debug=is_debug
)
webhook_data: RawWebhookDataDict
try:
# Use new unified extraction and validation
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

View File

@ -2,7 +2,7 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
query = ConversationListQuery.model_validate(raw_args)
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@ -4,7 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
@ -180,17 +180,18 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt)
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AuthenticationFailedError()
return {"result": "success"}
def _update_existing_account(self, account: Account, password_hashed, salt):
def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
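
The added session.commit() above is not incidental: with sessionmaker(engine).begin() the transaction commits automatically on a clean exit, while a plain Session context manager only closes the session. A minimal sketch of the difference, assuming a throwaway SQLite engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")

# sessionmaker(...).begin() yields a session inside a transaction that is
# committed automatically when the block exits without an exception:
with sessionmaker(engine).begin() as session:
    session.execute(text("SELECT 1"))

# A plain Session context manager only closes the session on exit; writes
# are discarded unless commit() is called explicitly -- hence the added
# call after the switch away from .begin():
with Session(engine) as session:
    session.execute(text("SELECT 1"))
    session.commit()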

View File

@ -1,12 +1,12 @@
from collections.abc import Callable
from datetime import UTC, datetime
from functools import wraps
from typing import Concatenate
from typing import Concatenate, ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from constants import HEADER_NAME_APP_CODE
@ -20,13 +20,14 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
P = ParamSpec("P")
R = TypeVar("R")
def validate_jwt_token[**P, R](
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: P.args, **kwargs: P.kwargs):
app_model, end_user = decode_jwt_token()
return view(app_model, end_user, *args, **kwargs)
@ -37,7 +38,7 @@ def validate_jwt_token[**P, R](
return decorator
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
system_features = FeatureService.get_system_features()
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
@ -48,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) ->
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
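
The ParamSpec/Concatenate pattern this file falls back to deserves a standalone sketch. The decorator below (inject_user is an illustrative name, not this codebase's) injects a leading argument while keeping the rest of the wrapped signature intact for type checkers, which is what validate_jwt_token does with the decoded (App, EndUser) pair:

from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def inject_user(view: Callable[Concatenate[str, P], R]) -> Callable[P, R]:
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        user = "end-user-123"  # stands in for the decode_jwt_token() lookup
        return view(user, *args, **kwargs)
    return decorated

@inject_user
def greet(user: str, message: str) -> str:
    return f"{user}: {message}"

print(greet("hello"))  # the user argument is supplied by the decorator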

View File

@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
args: Mapping,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Account | EndUser,
args: LoopNodeRunPayload,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
self,
*,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager,
conversation: ConversationSnapshot,
message: MessageSnapshot,
user: Account | EndUser,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
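
All the `Account | EndUser` to `Union[Account, EndUser]` rewrites in this hunk are spelling changes only: the two forms denote the same type, and on Python 3.10+ they even compare equal. A two-line sketch:

from typing import Union

def handle(user: Union[int, str]) -> str:  # same as `user: int | str` on 3.10+
    return str(user)

assert Union[int, str] == (int | str)  # equal type objects on Python 3.10+

The practical difference is that the Union[...] spelling also evaluates at runtime on interpreters older than 3.10, which is presumably the point of this rollback.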

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, overload
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
self,
*,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping | Generator[Mapping | str, None, None]: ...
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping | Generator[Mapping | str, None, None]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, overload
from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Generate App response.

View File

@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, overload
from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = False,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
def generate(
self,
app_model: App,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
message_id: str,
user: Account | EndUser,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Mapping | Generator[Mapping | str, None, None]:
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
"""
Generate App response.

View File

@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, cast, overload
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
*,
pipeline: Pipeline,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
with Session(db.engine, expire_on_commit=False) as session:
@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Account | EndUser,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: str | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.get(Workflow, workflow_id)
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Account | EndUser,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Account | EndUser,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
pipeline: Pipeline,
workflow: Workflow,
start_node_id: str,
user: Account | EndUser,
user: Union[Account, EndUser],
) -> list[Mapping[str, Any]]:
"""
Format datasource info list.

View File

@ -9,7 +9,6 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@ -85,13 +84,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.get(Pipeline, app_config.app_id)
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
@ -214,10 +213,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.scalar(
select(Workflow)
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.limit(1)
.first()
)
# return workflow
@ -298,8 +297,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
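
The query rewrites in this file swap SQLAlchemy 2.0-style select() calls for the legacy Query API; both return the same rows. A self-contained sketch with an illustrative model (not this repo's Workflow):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Workflow(Base):
    __tablename__ = "workflows"
    id: Mapped[str] = mapped_column(primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # 2.0 style: build a select() and fetch one scalar or None
    wf_a = session.scalar(select(Workflow).where(Workflow.id == "w1").limit(1))
    # legacy Query API, as used after this change: same result
    wf_b = session.query(Workflow).where(Workflow.id == "w1").first()
    # primary-key lookup, also equivalent here
    wf_c = session.get(Workflow, "w1")
    assert wf_a is wf_b is wf_c is None  # the table is empty in this sketch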

View File

@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from graphon.graph_engine.layers import GraphEngineLayer
@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Resume a paused workflow execution using the persisted runtime state.
"""
@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Account | EndUser,
user: Union[Account, EndUser],
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self
from typing import Annotated, Literal, Self, TypeAlias
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
entity: AdvancedChatAppGenerateEntity
type _GenerateEntityUnion = Annotated[
_GenerateEntityUnion: TypeAlias = Annotated[
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
Field(discriminator="type"),
]

View File

@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
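
The small-looking change above (`current_quota_type` to `current_quota_type.value`) matters when the column holds plain strings and the enum is a plain Enum rather than a StrEnum: a bare member never equals its string value. A minimal sketch under that assumption (QuotaType here is hypothetical):

from enum import Enum

class QuotaType(Enum):
    TRIAL = "trial"

# A plain Enum member is not equal to its underlying string, so a raw
# comparison against a string-typed column would never match:
assert QuotaType.TRIAL != "trial"
assert QuotaType.TRIAL.value == "trial"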

View File

@ -2,7 +2,7 @@ import logging
import time
from collections.abc import Generator
from threading import Thread
from typing import Any, cast
from typing import Any, Union, cast
from graphon.file import FileTransferMethod
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -72,12 +72,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
_task_state: EasyUITaskState
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@ -115,11 +117,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def process(
self,
) -> (
ChatbotAppBlockingResponse
| CompletionAppBlockingResponse
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
):
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
@ -134,7 +136,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
"""
Process blocking response.
:return:
@ -146,7 +148,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION:
response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@ -181,7 +183,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
"""
To stream response.
:return:

View File

@ -5,13 +5,14 @@ This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final, override
from typing import TYPE_CHECKING, cast, final
from graphon.enums import BuiltinNodeTypes
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
from graphon.nodes.base.node import Node
from typing_extensions import override
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

View File

@ -8,9 +8,8 @@ associates with the node span.
"""
import logging
from contextvars import Token
from dataclasses import dataclass
from typing import cast, final, override
from typing import cast, final
from graphon.enums import BuiltinNodeTypes, NodeType
from graphon.graph_engine.layers import GraphEngineLayer
@ -18,6 +17,7 @@ from graphon.graph_events import GraphNodeEventBase
from graphon.nodes.base.node import Node
from opentelemetry import context as context_api
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
from typing_extensions import override
from configs import dify_config
from extensions.otel.parser import (
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: Token[context_api.Context]
token: object
@final

View File

@ -153,7 +153,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.get(UploadFile, id)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
if not upload_file:
return None
@ -171,7 +171,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = db.session.get(MessageFile, id)
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
# Check if message_file is not None
if message_file is not None:
@ -185,7 +185,7 @@ class DatasourceFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id)
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
@ -203,7 +203,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
return None, None

View File

@ -44,8 +44,7 @@ class HumanInputContent(BaseModel):
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = [
"ExecutionExtraContentDomainModel",

View File

@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
return session.execute(stmt).scalar_one_or_none()
@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
)
if exclude_id:
@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = s.execute(stmt).scalar_one_or_none()
original_credentials = (
@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
credential_name=credential_name,
)
@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
credential_id=credential.id,
is_valid=True,
)
@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
return session.execute(stmt).scalars().first()
@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True,
)
@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False,
)
@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
load_balancing_config_count = session.execute(stmt).scalar() or 0
@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True,
)
@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False,
)

View File

@ -17,7 +17,7 @@ class CSVSanitizer:
"""
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
FORMULA_CHARS = frozenset(("=", "+", "-", "@", "\t", "\r"))
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
@classmethod
def sanitize_value(cls, value: Any) -> str:

View File

@ -2,13 +2,12 @@ import importlib.util
import logging
import sys
from types import ModuleType
from typing import AnyStr
logger = logging.getLogger(__name__)
def import_module_from_source[T: (str, bytes)](
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
) -> ModuleType:
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
"""
Importing a module from the source file directly
"""

View File

@ -2,6 +2,7 @@ import os
from collections import OrderedDict
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
from configs import dify_config
from core.tools.utils.yaml_utils import load_yaml_file_cached
@ -64,7 +65,10 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
return position_map
def is_filtered[T](
T = TypeVar("T")
def is_filtered(
include_set: set[str],
exclude_set: set[str],
data: T,
@ -93,11 +97,11 @@ def is_filtered[T](
return False
def sort_by_position_map[T](
def sort_by_position_map(
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
) -> list[T]:
):
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
@ -112,11 +116,11 @@ def sort_by_position_map[T](
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
def sort_to_dict_by_position_map[T](
def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[T],
name_func: Callable[[T], str],
) -> OrderedDict[str, T]:
):
"""
Sort the objects into an ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.

View File

@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
import logging
import time
from typing import Any
from typing import Any, TypeAlias
import httpx
from pydantic import TypeAdapter, ValidationError
@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
type Headers = dict[str, str]
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
Headers: TypeAlias = dict[str, str]
_HEADERS_ADAPTER = TypeAdapter(Headers)
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
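
For reference, the same alias in its two spellings: the PEP 695 `type` statement requires Python 3.12+, while the TypeAlias annotation used after this change works on older interpreters:

from typing import TypeAlias

Headers: TypeAlias = dict[str, str]   # pre-3.12 spelling, as in the diff
# type Headers = dict[str, str]       # PEP 695 spelling, 3.12+ only

headers: Headers = {"content-type": "application/json"}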

View File

@ -10,7 +10,7 @@ from typing import Any
from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import delete, func, select, update
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@ -78,7 +78,7 @@ class IndexingRunner:
continue
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@ -95,7 +95,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.get(Account, requeried_document.created_by)
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@ -137,24 +137,23 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id))
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
@ -168,7 +167,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.get(Account, requeried_document.created_by)
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@ -208,18 +207,17 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
documents = []
if document_segments:
@ -291,7 +289,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = db.session.get(Dataset, dataset_id)
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
@ -654,26 +652,24 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = db.session.get(Dataset, dataset_id)
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
@ -707,19 +703,17 @@ class IndexingRunner:
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
@ -740,17 +734,10 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = (
db.session.scalar(
select(func.count())
.select_from(DatasetDocument)
.where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True)
)
or 0
)
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedError()
document = db.session.get(DatasetDocument, document_id)
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedError()
@ -758,7 +745,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.commit()
@staticmethod
@ -766,9 +753,7 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
db.session.execute(
update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params)
)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(

View File

@ -3,19 +3,13 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, TypedDict
from typing import Any
import orjson
from configs import dify_config
class IdentityDict(TypedDict, total=False):
tenant_id: str
user_id: str
user_type: str
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -90,7 +84,7 @@ class StructuredJSONFormatter(logging.Formatter):
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
@ -98,7 +92,7 @@ class StructuredJSONFormatter(logging.Formatter):
if not any([tenant_id, user_id, user_type]):
return None
identity: IdentityDict = {}
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
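
What the removed IdentityDict bought, for comparison: with total=False every key is optional, but key names and value types stay checked, which the plain dict[str, str] replacement gives up. A minimal sketch reusing the removed definition:

from typing import TypedDict

class IdentityDict(TypedDict, total=False):
    tenant_id: str
    user_id: str
    user_type: str

identity: IdentityDict = {}
identity["tenant_id"] = "t-1"      # checked: known key, str value
# identity["unknown"] = "x"        # a type checker would flag this key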

View File

@ -3,7 +3,7 @@ import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, final
from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
@ -33,9 +33,9 @@ class _StatusError:
# Type aliases for better readability
type ReadQueue = queue.Queue[SessionMessage | Exception | None]
type WriteQueue = queue.Queue[SessionMessage | Exception | None]
type StatusQueue = queue.Queue[_StatusReady | _StatusError]
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
class SSETransport:

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
@ -9,12 +9,13 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT

View File

@ -260,12 +260,4 @@ def convert_input_form_to_parameters(
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
elif item.type == VariableEntityType.CHECKBOX:
parameters[item.variable]["type"] = "boolean"
elif item.type == VariableEntityType.JSON_OBJECT:
parameters[item.variable]["type"] = "object"
if item.json_schema:
for key in ("properties", "required", "additionalProperties"):
if key in item.json_schema:
parameters[item.variable][key] = item.json_schema[key]
return parameters, required

View File

@ -4,7 +4,7 @@ from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from datetime import timedelta
from types import TracebackType
from typing import Any, Self
from typing import Any, Generic, Self, TypeVar
from httpx import HTTPStatusError
from pydantic import BaseModel
@ -34,10 +34,16 @@ from core.mcp.types import (
logger = logging.getLogger(__name__)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
@ -54,7 +60,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
"""
request: ReceiveRequestT
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
_session: Any
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
@ -62,7 +68,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
):
self.request_id = request_id
@ -105,7 +111,7 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
self.completed = True
self._session.send_response(request_id=self.request_id, response=response)
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self):
"""Cancel this request and mark it as completed."""
@ -114,19 +120,21 @@ class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResul
self.completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session.send_response(
self._session._send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
class BaseSession[
SendRequestT: ClientRequest | ServerRequest,
SendNotificationT: ClientNotification | ServerNotification,
SendResultT: ClientResult | ServerResult,
ReceiveRequestT: ClientRequest | ServerRequest,
ReceiveNotificationT: ClientNotification | ServerNotification,
]:
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
@ -196,13 +204,13 @@ class BaseSession[
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request[T: BaseModel](
def send_request(
self,
request: SendRequestT,
result_type: type[T],
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata | None = None,
) -> T:
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
@ -291,7 +299,7 @@ class BaseSession[
)
self._write_stream.put(session_message)
def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@ -342,7 +350,7 @@ class BaseSession[
responder = RequestResponder[ReceiveRequestT, SendResultT](
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
@ -364,8 +372,8 @@ class BaseSession[
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification) # type: ignore[arg-type]
self._handle_incoming(notification) # type: ignore[arg-type]
self._received_notification(notification)
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
from pydantic.networks import AnyUrl, UrlConstraints
@ -31,7 +31,7 @@ ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
type AnyFunction = Callable[..., Any]
AnyFunction: TypeAlias = Callable[..., Any]
class RequestParams(BaseModel):
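Editor's note: `type AnyFunction = ...` is PEP 695 alias syntax and requires Python 3.12; the `typing.TypeAlias` annotation used in the hunk expresses the same alias on 3.10/3.11. A minimal demonstration:

```python
from collections.abc import Callable
from typing import Any, TypeAlias

# 3.12+ only:   type AnyFunction = Callable[..., Any]
# 3.10+ equivalent, as in the diff:
AnyFunction: TypeAlias = Callable[..., Any]

def apply(fn: AnyFunction, *args: Any) -> Any:
    return fn(*args)

print(apply(max, 1, 2))  # 2
```

The same substitution appears further down for `Content: TypeAlias = ContentBlock`.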
@ -68,7 +68,12 @@ class NotificationParams(BaseModel):
"""
class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
"""Base class for JSON-RPC requests."""
method: MethodT
@ -76,14 +81,14 @@ class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: st
model_config = ConfigDict(extra="allow")
class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
"""Base class for paginated requests,
matching the schema's PaginatedRequest interface."""
params: PaginatedRequestParams | None = None
class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
"""Base class for JSON-RPC notifications."""
method: MethodT
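Editor's note: with `RequestParamsT`/`MethodT` declared at module scope, concrete request types parametrize `Request` exactly as under PEP 695. A hedged, self-contained sketch of the pattern — `PingRequest` is illustrative, not code from the diff, and the TypeVar bound is simplified here:

```python
from typing import Any, Generic, Literal, TypeVar
from pydantic import BaseModel, ConfigDict

RequestParamsT = TypeVar("RequestParamsT", bound=dict[str, Any] | None)
MethodT = TypeVar("MethodT", bound=str)

class Request(BaseModel, Generic[RequestParamsT, MethodT]):
    """Base class for JSON-RPC requests, mirroring the diff above."""
    method: MethodT
    params: RequestParamsT | None = None
    model_config = ConfigDict(extra="allow")

# Illustrative subclass: Literal["ping"] satisfies the `bound=str` constraint.
class PingRequest(Request[None, Literal["ping"]]):
    method: Literal["ping"] = "ping"

print(PingRequest().model_dump())  # {'method': 'ping', 'params': None}
```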
@ -731,7 +736,7 @@ class ResourceLink(Resource):
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
"""A content block that can be used in prompts and tool results."""
type Content = ContentBlock
Content: TypeAlias = ContentBlock
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""

View File

@ -16,13 +16,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
DEPLOYMENT_ENVIRONMENT,
)
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
@ -51,10 +45,10 @@ class TraceClient:
self.endpoint = endpoint
self.resource = Resource(
attributes={
service_attributes.SERVICE_NAME: service_name,
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
HOST_NAME: socket.gethostname(),
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ACS_ARMS_SERVICE_FEATURE: "genai_app",
}
)
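Editor's note: the tracer swaps the incubating `deployment_attributes`/`host_attributes` modules (which needed `# type: ignore[import-untyped]`) for the long-standing `ResourceAttributes` constants, so all four keys come from one typed import. A minimal sketch, assuming the OpenTelemetry SDK is installed; the attribute values are illustrative placeholders:

```python
import socket

from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.resource import ResourceAttributes

# One import now covers service, deployment, and host keys that previously
# came from three separate semconv modules.
resource = Resource(
    attributes={
        ResourceAttributes.SERVICE_NAME: "dify-api",
        ResourceAttributes.SERVICE_VERSION: "dify-1.0.0-abc123",
        ResourceAttributes.DEPLOYMENT_ENVIRONMENT: "production",
        ResourceAttributes.HOST_NAME: socket.gethostname(),
    }
)
print(resource.attributes[ResourceAttributes.SERVICE_NAME])
```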

View File

@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
@ -38,7 +38,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@ -135,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
exception_attributes.EXCEPTION_TYPE: exception_type,
exception_attributes.EXCEPTION_MESSAGE: exception_message,
exception_attributes.EXCEPTION_ESCAPED: False,
exception_attributes.EXCEPTION_STACKTRACE: error_string,
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
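Editor's note: the same swap on the span side — `SpanAttributes` from `opentelemetry.semconv.trace` carries the exception keys previously imported from `opentelemetry.semconv.attributes.exception_attributes`. A short runnable sketch of recording an exception event with the aliased constants (the tracer setup is illustrative):

```python
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("workflow-node") as span:
    try:
        raise RuntimeError("boom")
    except RuntimeError as error:
        # Same event shape set_span_status builds in the diff above.
        span.add_event(
            name="exception",
            attributes={
                OTELSpanAttributes.EXCEPTION_TYPE: type(error).__name__,
                OTELSpanAttributes.EXCEPTION_MESSAGE: str(error),
                OTELSpanAttributes.EXCEPTION_ESCAPED: False,
                OTELSpanAttributes.EXCEPTION_STACKTRACE: "",
            },
        )
```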
@ -470,7 +469,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
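Editor's note: `JSON_DICT_ADAPTER` (a helper from the now-removed `core.ops.utils` import) is replaced by plain `json.loads`. The observable difference is validation: assuming the adapter was a pydantic `TypeAdapter(dict[str, Any])`, it rejected non-dict payloads, whereas `json.loads` returns whatever the JSON encodes. A sketch of the two behaviours under that assumption:

```python
import json
from typing import Any

from pydantic import TypeAdapter, ValidationError

# Assumed shape of the removed helper.
JSON_DICT_ADAPTER = TypeAdapter(dict[str, Any])

payload = '{"model_parameters": {"temperature": 0.7}}'
print(JSON_DICT_ADAPTER.validate_json(payload))  # validated dict
print(json.loads(payload))                       # plain dict, no type check

try:
    JSON_DICT_ADAPTER.validate_json("[1, 2, 3]")  # adapter rejects non-dict JSON
except ValidationError as e:
    print("adapter rejected list:", e.error_count(), "error(s)")

print(json.loads("[1, 2, 3]"))  # json.loads happily returns a list
```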

View File

@ -1,19 +1,9 @@
import logging
import os
import uuid
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta
from graphon.enums import BuiltinNodeTypes
from langfuse import Langfuse
from langfuse.api import (
CreateGenerationBody,
CreateSpanBody,
IngestionEvent_GenerationCreate,
IngestionEvent_SpanCreate,
IngestionEvent_TraceCreate,
TraceBody,
)
from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@ -406,61 +396,18 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_span(langfuse_span_data=name_generation_span_data)
def _make_event_id(self) -> str:
return str(uuid.uuid4())
def _now_iso(self) -> str:
return datetime.now(UTC).isoformat()
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
try:
body = TraceBody(
id=data.get("id"),
name=data.get("name"),
user_id=data.get("user_id"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
session_id=data.get("session_id"),
version=data.get("version"),
release=data.get("release"),
tags=data.get("tags"),
public=data.get("public"),
)
event = IngestionEvent_TraceCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.trace(**format_trace_data)
logger.debug("LangFuse Trace created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
try:
body = CreateSpanBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
)
event = IngestionEvent_SpanCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.span(**format_span_data)
logger.debug("LangFuse Span created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
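Editor's note: the Langfuse tracer drops the hand-built ingestion batch (`TraceBody`/`IngestionEvent_*` plus `api.ingestion.batch`) and returns to the SDK's high-level helpers, which assign event ids and timestamps and batch ingestion internally — exactly the boilerplate the removed code reimplemented. A hedged sketch of the restored call style, assuming a v2-era `langfuse` SDK where `trace()`/`span()` are client methods; credentials are placeholders:

```python
from langfuse import Langfuse

langfuse = Langfuse(
    public_key="pk-...",
    secret_key="sk-...",
    host="https://cloud.langfuse.com",
)

# One call per object; the SDK buffers and flushes batches itself.
trace = langfuse.trace(name="workflow", user_id="user-1", metadata={"app": "dify"})
langfuse.span(trace_id=trace.id, name="node", input={"q": "hi"}, output={"a": "hello"})
langfuse.flush()  # force-send buffered events before exit
```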
@ -471,45 +418,11 @@ class LangFuseDataTrace(BaseTraceInstance):
span.end(**format_span_data)
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
format_generation_data = (
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
)
try:
usage_data = data.pop("usage", None)
usage = None
if usage_data:
usage = Usage(
input=usage_data.get("input", 0) or 0,
output=usage_data.get("output", 0) or 0,
total=usage_data.get("total", 0) or 0,
unit=usage_data.get("unit"),
input_cost=usage_data.get("inputCost"),
output_cost=usage_data.get("outputCost"),
total_cost=usage_data.get("totalCost"),
)
body = CreateGenerationBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
model=data.get("model"),
model_parameters=data.get("model_parameters"),
input=data.get("input"),
output=data.get("output"),
usage=usage,
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
completion_start_time=data.get("completion_start_time"),
)
event = IngestionEvent_GenerationCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.generation(**format_generation_data)
logger.debug("LangFuse Generation created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
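Editor's note: `add_generation` follows the same pattern — instead of translating the usage dict into a `Usage` model for `CreateGenerationBody`, it forwards the filtered kwargs straight to `langfuse_client.generation(...)`, which accepts a plain usage mapping. A hedged, self-contained sketch under the same v2-SDK assumption; ids and values are illustrative:

```python
from langfuse import Langfuse

langfuse = Langfuse(
    public_key="pk-...",
    secret_key="sk-...",
    host="https://cloud.langfuse.com",
)

langfuse.generation(
    trace_id="trace-123",  # illustrative id
    name="llm-call",
    model="gpt-4o-mini",
    model_parameters={"temperature": 0.7},
    input=[{"role": "user", "content": "hi"}],
    output="hello",
    # The generic usage mapping replaces the removed Usage(...) wrapper.
    usage={"input": 12, "output": 5, "total": 17, "unit": "TOKENS"},
)
langfuse.flush()
```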
@ -530,7 +443,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def get_project_key(self):
try:
projects = self.langfuse_client.api.projects.get()
projects = self.langfuse_client.client.projects.get()
return projects.data[0].id
except Exception as e:
logger.debug("LangFuse get project key failed: %s", str(e))
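Editor's note: `get_project_key` likewise moves off the `.api` accessor — in the v2 SDK the generated REST client hangs off `langfuse_client.client`, so the projects endpoint is reached as `client.projects.get()`. A hedged probe sketch (same SDK assumption as above); fetching the first project's id doubles as a cheap credentials check:

```python
from langfuse import Langfuse

langfuse = Langfuse(
    public_key="pk-...",
    secret_key="sk-...",
    host="https://cloud.langfuse.com",
)

try:
    projects = langfuse.client.projects.get()
    print(projects.data[0].id)
except Exception as e:
    print("Langfuse credentials check failed:", e)
```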

View File

@ -1,3 +1,4 @@
import json
import logging
import os
from datetime import datetime, timedelta
@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance):
inputs = node.process_data # contains request URL
if not inputs:
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
inputs = json.loads(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
name=node.title,
@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {}
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == BuiltinNodeTypes.LLM:
@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance):
return {}, {}
try:
data = JSON_DICT_ADAPTER.validate_json(node.process_data)
except (ValueError, TypeError):
data = json.loads(node.process_data)
except (json.JSONDecodeError, TypeError):
return {}, {}
inputs = self._parse_prompts(data.get("prompts"))
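Editor's note: narrowing the except clause from `ValueError` to `json.JSONDecodeError` is safe because `JSONDecodeError` subclasses `ValueError` — malformed JSON is still caught, while unrelated `ValueError`s raised inside the try block now propagate instead of being silently swallowed. A quick demonstration:

```python
import json

print(issubclass(json.JSONDecodeError, ValueError))  # True

def parse(raw):
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # Matches the diff: malformed strings and None both fall back to {}.
        return {}

print(parse('{"a": 1}'))  # {'a': 1}
print(parse("not json"))  # {}  (JSONDecodeError)
print(parse(None))        # {}  (TypeError: json.loads expects str/bytes)
```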

Some files were not shown because too many files have changed in this diff