Compare commits

..

1 Commits

Author SHA1 Message Date
6c085edd13 chore(deps): bump the python-packages group across 1 directory with 3 updates
Updates the requirements on [resend](https://github.com/resendlabs/resend-python), [sentry-sdk](https://github.com/getsentry/sentry-python) and [unstructured](https://github.com/Unstructured-IO/unstructured) to permit the latest version.

Updates `resend` to 2.27.0
- [Release notes](https://github.com/resendlabs/resend-python/releases)
- [Commits](https://github.com/resendlabs/resend-python/compare/v2.26.0...v2.27.0)

Updates `sentry-sdk` to 2.57.0
- [Release notes](https://github.com/getsentry/sentry-python/releases)
- [Changelog](https://github.com/getsentry/sentry-python/blob/master/CHANGELOG.md)
- [Commits](https://github.com/getsentry/sentry-python/compare/2.55.0...2.57.0)

Updates `unstructured` to 0.22.16
- [Release notes](https://github.com/Unstructured-IO/unstructured/releases)
- [Changelog](https://github.com/Unstructured-IO/unstructured/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Unstructured-IO/unstructured/compare/0.21.5...0.22.16)

---
updated-dependencies:
- dependency-name: resend
  dependency-version: 2.27.0
  dependency-type: direct:production
  dependency-group: python-packages
- dependency-name: sentry-sdk
  dependency-version: 2.57.0
  dependency-type: direct:production
  dependency-group: python-packages
- dependency-name: unstructured
  dependency-version: 0.22.16
  dependency-type: direct:production
  dependency-group: python-packages
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-06 01:16:27 +00:00
522 changed files with 5469 additions and 14498 deletions

9
.github/labeler.yml vendored
View File

@ -1,10 +1,3 @@
web:
- changed-files:
- any-glob-to-any-file:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- any-glob-to-any-file: 'web/**'

View File

@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@ -1,82 +0,0 @@
import { execFileSync } from 'node:child_process'
import fs from 'node:fs'
import path from 'node:path'
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
outputPath,
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)

View File

@ -39,11 +39,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'

View File

@ -65,7 +65,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
@ -130,7 +130,7 @@ jobs:
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@4907a6ddec9925e35a0a9e82d7399ccc52663121 # v4.1.0
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}

View File

@ -8,11 +8,9 @@ on:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:

View File

@ -65,11 +65,9 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
@ -79,11 +77,9 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'packages/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.npmrc'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'

View File

@ -77,11 +77,9 @@ jobs:
with:
files: |
web/**
packages/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.npmrc
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@ -151,7 +149,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@9e863354e3ff62e0727d37183162c4a88873df41 # v8.6.0
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@ -9,7 +9,6 @@ on:
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@ -68,7 +68,89 @@ jobs:
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() {
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
}
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
@ -158,7 +240,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
@ -188,7 +270,7 @@ jobs:
Tool rules:
- Use Read for repository files.
- Use Edit for JSON updates.
- Use Bash only for `vp`.
- Use Bash only for `pnpm`.
- Do not use Bash for `git`, `gh`, or branch management.
Required execution plan:
@ -210,7 +292,7 @@ jobs:
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
4. Run a scoped pre-check before editing:
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations.
- For every target language and scoped file:
@ -218,19 +300,19 @@ jobs:
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys.
- UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching.
6. Verify only the edited files.
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests.
claude_args: |
--max-turns 120
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
- name: Prepare branch metadata
id: pr_meta
@ -272,7 +354,6 @@ jobs:
- name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
@ -321,8 +402,8 @@ jobs:
'',
'## Verification',
'',
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
'',
'## Notes',
'',

View File

@ -42,7 +42,88 @@ jobs:
fi
export BASE_SHA HEAD_SHA CHANGED_FILES
node .github/scripts/generate-i18n-changes.mjs
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"

View File

@ -36,7 +36,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@ -81,8 +81,8 @@ if $web_modified; then
if $web_ts_modified; then
echo "Running TypeScript type-check:tsgo"
if ! npm run type-check:tsgo; then
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
if ! pnpm run type-check:tsgo; then
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
exit 1
fi
else
@ -90,10 +90,36 @@ if $web_modified; then
fi
echo "Running knip"
if ! npm run knip; then
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
if ! pnpm run knip; then
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
exit 1
fi
echo "Running unit tests check"
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
if [ -n "$modified_files" ]; then
for file in $modified_files; do
test_file="${file%.*}.spec.ts"
echo "Checking for test file: $test_file"
# check if the test file exists
if [ -f "../$test_file" ]; then
echo "Detected changes in $file, running corresponding unit tests..."
pnpm run test "../$test_file"
if [ $? -ne 0 ]; then
echo "Unit tests failed. Please fix the errors before committing."
exit 1
fi
echo "Unit tests for $file passed."
else
echo "Warning: $file does not have a corresponding test file."
fi
done
echo "All unit tests for modified web/utils files have passed."
fi
cd ../
fi

View File

@ -71,13 +71,6 @@ REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
REDIS_RETRY_RETRIES=3
REDIS_RETRY_BACKOFF_BASE=1.0
REDIS_RETRY_BACKOFF_CAP=10.0
REDIS_SOCKET_TIMEOUT=5.0
REDIS_SOCKET_CONNECT_TIMEOUT=5.0
REDIS_HEALTH_CHECK_INTERVAL=30
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis

View File

@ -1,18 +0,0 @@
# This module provides a lightweight Celery instance for use in Docker health checks.
# Unlike celery_entrypoint.py, this does NOT import app.py and therefore avoids
# initializing all Flask extensions (DB, Redis, storage, blueprints, etc.).
# Using this module keeps the health check fast and low-cost.
from celery import Celery
from configs import dify_config
from extensions.ext_celery import get_celery_broker_transport_options, get_celery_ssl_options
celery = Celery(broker=dify_config.CELERY_BROKER_URL)
broker_transport_options = get_celery_broker_transport_options()
if broker_transport_options:
celery.conf.update(broker_transport_options=broker_transport_options)
ssl_options = get_celery_ssl_options()
if ssl_options:
celery.conf.update(broker_use_ssl=ssl_options)

View File

@ -1,7 +1,7 @@
import datetime
import logging
import time
from typing import TypedDict
from typing import Any
import click
import sqlalchemy as sa
@ -503,19 +503,7 @@ def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]:
return [row[0] for row in result]
class _AppOrphanCounts(TypedDict):
variables: int
files: int
class OrphanedDraftVariableStatsDict(TypedDict):
total_orphaned_variables: int
total_orphaned_files: int
orphaned_app_count: int
orphaned_by_app: dict[str, _AppOrphanCounts]
def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
def _count_orphaned_draft_variables() -> dict[str, Any]:
"""
Count orphaned draft variables by app, including associated file counts.
@ -538,7 +526,7 @@ def _count_orphaned_draft_variables() -> OrphanedDraftVariableStatsDict:
with db.engine.connect() as conn:
result = conn.execute(sa.text(variables_query))
orphaned_by_app: dict[str, _AppOrphanCounts] = {}
orphaned_by_app = {}
total_files = 0
for row in result:

View File

@ -117,37 +117,6 @@ class RedisConfig(BaseSettings):
default=None,
)
REDIS_RETRY_RETRIES: NonNegativeInt = Field(
description="Maximum number of retries per Redis command on "
"transient failures (ConnectionError, TimeoutError, socket.timeout)",
default=3,
)
REDIS_RETRY_BACKOFF_BASE: PositiveFloat = Field(
description="Base delay in seconds for exponential backoff between retries",
default=1.0,
)
REDIS_RETRY_BACKOFF_CAP: PositiveFloat = Field(
description="Maximum backoff delay in seconds between retries",
default=10.0,
)
REDIS_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis read/write operations",
default=5.0,
)
REDIS_SOCKET_CONNECT_TIMEOUT: PositiveFloat | None = Field(
description="Socket timeout in seconds for Redis connection establishment",
default=5.0,
)
REDIS_HEALTH_CHECK_INTERVAL: NonNegativeInt = Field(
description="Interval in seconds between Redis connection health checks (0 to disable)",
default=30,
)
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
@classmethod
def _empty_string_to_none_for_max_conns(cls, v):

View File

@ -1,79 +0,0 @@
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
from libs.helper import UUIDStrOrEmpty
# --- Conversation schemas ---
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
# --- Message schemas ---
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
# --- Saved message schemas ---
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
# --- Workflow schemas ---
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None

View File

@ -2,7 +2,6 @@ import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import cast
from flask import request
from flask_restx import Resource
@ -18,7 +17,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService, LangContentDict
from services.billing_service import BillingService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -329,7 +328,7 @@ class UpsertNotificationApi(Resource):
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[cast(LangContentDict, c.model_dump()) for c in payload.contents],
contents=[c.model_dump() for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,

View File

@ -7,7 +7,7 @@ from flask import request
from flask_restx import Resource
from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest
@ -26,11 +26,9 @@ from controllers.console.wraps import (
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
@ -43,7 +41,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
NotionIcon,
NotionInfo,
NotionPage,
PreProcessingRule,
RerankingModel,
Rule,
Segmentation,
WebsiteInfo,
WeightKeywordSetting,
WeightModel,
@ -154,6 +155,16 @@ class AppTracePayload(BaseModel):
type JSONValue = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = current_user
@ -92,13 +92,11 @@ class AppImportApi(Resource):
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports/<string:import_id>/confirm")

View File

@ -8,7 +8,6 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, func, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@ -60,8 +59,10 @@ class ChatMessagesQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
@field_validator("message_id")
@classmethod

View File

@ -14,7 +14,6 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow_run import workflow_run_node_execution_model
@ -143,6 +142,10 @@ class PublishWorkflowPayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class ConvertToWorkflowPayload(BaseModel):
name: str | None = None
icon_type: str | None = None
@ -150,6 +153,18 @@ class ConvertToWorkflowPayload(BaseModel):
icon_background: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str

View File

@ -384,27 +384,24 @@ class VariableApi(Resource):
new_value = None
if raw_value is not None:
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=app_model.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -66,13 +66,13 @@ class WebhookTriggerApi(Resource):
with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
.where(
WorkflowWebhookTrigger.app_id == app_model.id,
WorkflowWebhookTrigger.node_id == node_id,
)
.limit(1)
.first()
)
if not webhook_trigger:

View File

@ -3,7 +3,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@ -20,18 +20,35 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password
from libs.password import hash_password, valid_password
from services.account_service import AccountService, TenantService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr = Field(...)
language: str | None = Field(default=None)
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr = Field(...)
code: str = Field(...)
token: str = Field(...)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")

View File

@ -1,3 +1,5 @@
from typing import Any
import flask_login
from flask import make_response, request
from flask_restx import Resource
@ -40,9 +42,8 @@ from libs.token import (
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
@ -50,7 +51,9 @@ from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class LoginPayload(LoginPayloadBase):
class LoginPayload(BaseModel):
email: EmailStr = Field(..., description="Email address")
password: str = Field(..., description="Password")
remember_me: bool = Field(default=False, description="Remember me flag")
invite_token: str | None = Field(default=None, description="Invitation token")
@ -98,7 +101,7 @@ class LoginApi(Resource):
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: InvitationDetailDict | None = None
invitation_data: dict[str, Any] | None = None
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:

View File

@ -3,7 +3,6 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@ -87,8 +86,8 @@ class CustomizedPipelineTemplateApi(Resource):
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
template = session.scalar(
select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).limit(1)
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@ -223,27 +223,24 @@ class RagPipelineVariableApi(Resource):
new_value = None
if raw_value is not None:
match variable.value_type:
case SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
case _:
pass
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(
mapping=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(
mappings=raw_value,
tenant_id=pipeline.tenant_id,
access_controller=_file_access_controller,
)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()

View File

@ -83,13 +83,11 @@ class RagPipelineImportApi(Resource):
# Return appropriate status code based on result
status = result.status
match status:
case ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
case ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
case ImportStatus.COMPLETED | ImportStatus.COMPLETED_WITH_WARNINGS:
return result.model_dump(mode="json"), 200
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
@console_ns.route("/rag/pipelines/imports/<string:import_id>/confirm")

View File

@ -10,7 +10,6 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import DefaultBlockConfigQuery, WorkflowListQuery, WorkflowUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
@ -95,6 +94,22 @@ class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
original_document_id: str | None = None
class DefaultBlockConfigQuery(BaseModel):
q: str | None = None
class WorkflowListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=10, ge=1, le=100)
user_id: str | None = None
named_only: bool = False
class WorkflowUpdatePayload(BaseModel):
marked_name: str | None = Field(default=None, max_length=20)
marked_comment: str | None = Field(default=None, max_length=100)
class NodeIdQuery(BaseModel):
node_id: str

View File

@ -2,10 +2,10 @@ import logging
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
AppUnavailableError,
@ -32,6 +32,14 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(console_ns, TextToAudioPayload)

View File

@ -1,11 +1,10 @@
from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
@ -33,6 +32,18 @@ class ConversationListQuery(BaseModel):
pinned: bool | None = None
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,10 +3,9 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console.app.error import (
AppMoreLikeThisDisabledError,
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
@ -44,6 +44,17 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = None
content: str | None = None
class MoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"]

View File

@ -1,18 +1,28 @@
from flask import request
from pydantic import TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,10 +1,11 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_model
from controllers.console.app.error import (
CompletionRequestError,
@ -33,6 +34,12 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
register_schema_model(console_ns, WorkflowRunPayload)

View File

@ -168,13 +168,12 @@ class ConsoleWorkflowEventsApi(Resource):
else:
msg_generator = MessageGenerator()
generator: BaseAppGenerator
match app.mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@ -1,5 +1,3 @@
from typing import TypedDict
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -13,21 +11,6 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
lang: str
title: str
subtitle: str
body: str
title_pic_url: str
class NotificationResponseDict(TypedDict):
should_show: bool
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
@ -62,30 +45,28 @@ class NotificationApi(Resource):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
response: NotificationResponseDict
if not result.get("shouldShow"):
response = {"should_show": False, "notifications": []}
return response, 200
return {"should_show": False, "notifications": []}, 200
lang = current_user.interface_language or _FALLBACK_LANG
notifications: list[NotificationItemDict] = []
notifications = []
for notification in result.get("notifications") or []:
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
notifications.append(item)
notifications.append(
{
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"subtitle": lang_content.get("subtitle", ""),
"body": lang_content.get("body", ""),
"title_pic_url": lang_content.get("titlePicUrl", ""),
}
)
response = {"should_show": bool(notifications), "notifications": notifications}
return response, 200
return {"should_show": bool(notifications), "notifications": notifications}, 200
@console_ns.route("/notification/dismiss")

View File

@ -9,14 +9,7 @@ from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from libs.login import current_account_with_tenant, login_required
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
@ -32,19 +25,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace):
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: TagType = Field(description="Tag type")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagListQueryParam(BaseModel):
@ -89,7 +82,7 @@ class TagListApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type))
tag = TagService.save_tags(payload.model_dump())
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -110,7 +103,7 @@ class TagUpdateDeleteApi(Resource):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id)
tag = TagService.update_tags(payload.model_dump(), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -143,9 +136,7 @@ class TagBindingCreateApi(Resource):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type)
)
TagService.save_tag_binding(payload.model_dump())
return {"result": "success"}, 200
@ -163,8 +154,6 @@ class TagBindingDeleteApi(Resource):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type)
)
TagService.delete_tag_binding(payload.model_dump())
return {"result": "success"}, 200

View File

@ -1,7 +1,6 @@
from collections.abc import Callable
from functools import wraps
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
@ -22,12 +21,12 @@ def plugin_permission_required(
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
permission = session.scalar(
select(TenantPluginPermission)
permission = (
session.query(TenantPluginPermission)
.where(
TenantPluginPermission.tenant_id == tenant_id,
)
.limit(1)
.first()
)
if not permission:

View File

@ -28,7 +28,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@ -240,10 +240,8 @@ class CustomConfigWorkspaceApi(Resource):
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
custom_config_dict: TenantCustomConfigDict = {
"remove_webapp_brand": args.remove_webapp_brand
if args.remove_webapp_brand is not None
else tenant.custom_config_dict.get("remove_webapp_brand", False),
custom_config_dict = {
"remove_webapp_brand": args.remove_webapp_brand,
"replace_webapp_logo": args.replace_webapp_logo
if args.replace_webapp_logo is not None
else tenant.custom_config_dict.get("replace_webapp_logo"),

View File

@ -9,7 +9,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.wraps import setup_required
@ -55,7 +55,7 @@ class EnterpriseAppDSLImport(Resource):
account.set_tenant_id(workspace_id)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
dsl_service = AppDslService(session)
result = dsl_service.import_app(
account=account,
@ -64,6 +64,7 @@ class EnterpriseAppDSLImport(Resource):
name=args.name,
description=args.description,
)
session.commit()
if result.status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400

View File

@ -4,7 +4,6 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.common.schema import register_schema_model
@ -81,11 +80,11 @@ class MCPAppApi(Resource):
def _get_mcp_server_and_app(self, server_code: str, session: Session) -> tuple[AppMCPServer, App]:
"""Get and validate MCP server and app in one query session"""
mcp_server = session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1))
mcp_server = session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not mcp_server:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server Not Found")
app = session.scalar(select(App).where(App.id == mcp_server.app_id).limit(1))
app = session.query(App).where(App.id == mcp_server.app_id).first()
if not app:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "App Not Found")
@ -191,12 +190,12 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
return session.scalar(
select(EndUser)
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
.where(EndUser.session_id == mcp_server_id)
.where(EndUser.type == "mcp")
.limit(1)
.first()
)
def _create_end_user(

View File

@ -2,12 +2,11 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, NotFound
import services
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@ -35,6 +34,18 @@ class ConversationListQuery(BaseModel):
)
class ConversationRenamePayload(BaseModel):
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")

View File

@ -1,4 +1,5 @@
import logging
from typing import Literal
from flask import request
from flask_restx import Resource
@ -6,7 +7,6 @@ from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
@ -14,6 +14,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
@ -26,6 +27,17 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class FeedbackListQuery(BaseModel):
page: int = Field(default=1, ge=1, description="Page number")
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")

View File

@ -1,5 +1,5 @@
import logging
from typing import Literal
from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
@ -11,7 +11,6 @@ from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@ -47,7 +46,9 @@ from services.workflow_app_service import WorkflowAppService
logger = logging.getLogger(__name__)
class WorkflowRunPayload(WorkflowRunPayloadBase):
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None

View File

@ -22,17 +22,10 @@ from fields.tag_fields import DataSetTag
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.enums import TagType
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import (
SaveTagPayload,
TagBindingCreatePayload,
TagBindingDeletePayload,
TagService,
UpdateTagPayload,
)
from services.tag_service import TagService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@ -520,7 +513,7 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
@ -543,8 +536,9 @@ class DatasetTagsApi(DatasetApiResource):
raise Forbidden()
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
params = {"name": payload.name, "type": "knowledge"}
tag_id = payload.tag_id
tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id)
tag = TagService.update_tags(params, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@ -591,9 +585,7 @@ class DatasetTagBindingApi(DatasetApiResource):
raise Forbidden()
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
TagService.save_tag_binding(
TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
return "", 204
@ -617,9 +609,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
raise Forbidden()
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
TagService.delete_tag_binding(
TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE)
)
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
return "", 204

View File

@ -31,7 +31,6 @@ from controllers.service_api.wraps import (
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError
from core.rag.entities import PreProcessingRule, Rule, Segmentation
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
@ -41,8 +40,11 @@ from models.enums import SegmentStatus
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import (
KnowledgeConfig,
PreProcessingRule,
ProcessRule,
RetrievalModel,
Rule,
Segmentation,
)
from services.file_service import FileService
from services.summary_index_service import SummaryIndexService

View File

@ -4,23 +4,13 @@ Serialization helpers for Service API knowledge pipeline endpoints.
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import UploadFile
class UploadFileDict(TypedDict):
id: str
name: str
size: int
extension: str
mime_type: str | None
created_by: str
created_at: str | None
def serialize_upload_file(upload_file: UploadFile) -> UploadFileDict:
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
return {
"id": upload_file.id,
"name": upload_file.name,

View File

@ -3,11 +3,10 @@ import logging
from flask import request
from flask_restx import fields, marshal_with
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import field_validator
from pydantic import BaseModel, field_validator
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@ -35,7 +34,12 @@ from services.errors.audio import (
from ..common.schema import register_schema_models
class TextToAudioPayload(TextToAudioPayloadBase):
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:

View File

@ -1,11 +1,10 @@
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import ConversationRenamePayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotChatAppError
@ -38,6 +37,18 @@ class ConversationListQuery(BaseModel):
return uuid_value(value)
class ConversationRenamePayload(BaseModel):
name: str | None = None
auto_generate: bool = False
@model_validator(mode="after")
def validate_name_requirement(self):
if not self.auto_generate:
if self.name is None or not self.name.strip():
raise ValueError("name is required when auto_generate is false")
return self
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)

View File

@ -3,6 +3,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
@ -18,15 +19,33 @@ from controllers.console.error import EmailSendIpLimitError
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import extract_remote_ip
from libs.password import hash_password
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService
from services.entities.auth_entities import (
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordSendPayload,
)
class ForgotPasswordSendPayload(BaseModel):
email: EmailStr
language: str | None = None
class ForgotPasswordCheckPayload(BaseModel):
email: EmailStr
code: str
token: str = Field(min_length=1)
class ForgotPasswordResetPayload(BaseModel):
token: str = Field(min_length=1)
new_password: str
password_confirm: str
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
register_schema_models(web_ns, ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload)

View File

@ -29,11 +29,13 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
class LoginPayload(LoginPayloadBase):
class LoginPayload(BaseModel):
email: EmailStr
password: str
@field_validator("password")
@classmethod
def validate_password(cls, value: str) -> str:

View File

@ -6,7 +6,6 @@ from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@ -54,6 +53,11 @@ class MessageListQuery(BaseModel):
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
content: str | None = Field(default=None, description="Feedback content")
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@ -138,15 +138,12 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
match auth_type:
case WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
case WebAppAuthType.EXTERNAL:
if user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
case WebAppAuthType.INTERNAL:
if user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
if auth_type == WebAppAuthType.PUBLIC:
return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
elif auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
raise WebAppAuthRequiredError("Please login as external user.")
elif auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
raise WebAppAuthRequiredError("Please login as internal user.")
end_user = None
if end_user_id:

View File

@ -1,17 +1,27 @@
from flask import request
from pydantic import TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
class SavedMessageListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class SavedMessageCreatePayload(BaseModel):
message_id: UUIDStrOrEmpty
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)

View File

@ -1,10 +1,11 @@
import logging
from typing import Any
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
from controllers.common.controller_schemas import WorkflowRunPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@ -29,6 +30,12 @@ from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
class WorkflowRunPayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
logger = logging.getLogger(__name__)
register_schema_models(web_ns, WorkflowRunPayload)

View File

@ -72,13 +72,12 @@ class WorkflowEventsApi(WebApiResource):
app_mode = AppMode.value_of(app_model.mode)
msg_generator = MessageGenerator()
generator: BaseAppGenerator
match app_mode:
case AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
case AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
case _:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"

View File

@ -79,18 +79,21 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
content = ""
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
content += f"Final Answer: {unit.agent_response}"
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
content += f"Thought: {unit.thought}\n\n"
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
content += f"Action: {unit.action_str}\n\n"
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
content += f"Observation: {unit.observation}\n\n"
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [AssistantPromptMessage(content=content)]
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])

View File

@ -5,10 +5,6 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class FeatureToggleDict(TypedDict):
enabled: bool
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
@ -20,12 +16,12 @@ class SystemParametersDict(TypedDict):
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: FeatureToggleDict
speech_to_text: FeatureToggleDict
text_to_speech: FeatureToggleDict
retriever_resource: FeatureToggleDict
annotation_reply: FeatureToggleDict
more_like_this: FeatureToggleDict
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@ -8,7 +9,6 @@ from graphon.variables.input_entities import VariableEntity as WorkflowVariableE
from pydantic import BaseModel, Field
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
from core.rag.entities import MetadataFilteringCondition
from models.model import AppMode
@ -111,6 +111,31 @@ class ExternalDataVariableEntity(BaseModel):
config: dict[str, Any] = Field(default_factory=dict)
SupportedComparisonOperator = Literal[
# for string or array
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
# for number
"=",
"",
">",
"<",
"",
"",
# for time
"before",
"after",
]
class ModelConfig(BaseModel):
provider: str
name: str
@ -118,6 +143,25 @@ class ModelConfig(BaseModel):
completion_params: dict[str, Any] = Field(default_factory=dict)
class Condition(BaseModel):
"""
Condition detail
"""
name: str
comparison_operator: SupportedComparisonOperator
value: str | Sequence[str] | None | int | float = None
class MetadataFilteringCondition(BaseModel):
"""
Metadata Filtering Condition.
"""
logical_operator: Literal["and", "or"] | None = "and"
conditions: list[Condition] | None = Field(default=None, deprecated=True)
class DatasetRetrieveConfigEntity(BaseModel):
"""
Dataset Retrieve Config Entity.

View File

@ -10,7 +10,7 @@ from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import Variable
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -363,7 +363,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
:return: List of conversation variables ready for use
"""
with sessionmaker(bind=db.engine).begin() as session:
with Session(db.engine) as session:
existing_variables = self._load_existing_conversation_variables(session)
if not existing_variables:
@ -376,6 +376,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# Convert to Variable objects for use in the workflow
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:

View File

@ -16,7 +16,7 @@ from graphon.model_runtime.utils.encoders import jsonable_encoder
from graphon.nodes import BuiltinNodeTypes
from graphon.runtime import GraphRuntimeState
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -328,8 +328,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@ -107,13 +107,13 @@ class AppGenerateResponseConverter(ABC):
return metadata
@classmethod
def _error_to_stream_response(cls, e: Exception) -> dict[str, Any]:
def _error_to_stream_response(cls, e: Exception):
"""
Error to stream response.
:param e: exception
:return:
"""
error_responses: dict[type[Exception], dict[str, Any]] = {
error_responses = {
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
@ -127,7 +127,7 @@ class AppGenerateResponseConverter(ABC):
}
# Determine the response based on the type of exception
data: dict[str, Any] | None = None
data = None
for k, v in error_responses.items():
if isinstance(e, k):
data = v

View File

@ -7,7 +7,7 @@ from typing import Union
from graphon.entities import WorkflowStartReason
from graphon.enums import WorkflowExecutionStatus
from graphon.runtime import GraphRuntimeState
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -252,8 +252,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
@contextmanager
def _database_session(self):
"""Context manager for database sessions."""
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
yield session
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""

View File

@ -66,7 +66,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.system_variables import (
build_bootstrap_variables,

View File

@ -10,7 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChun
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
class QueueEvent(StrEnum):

View File

@ -9,7 +9,7 @@ from graphon.nodes.human_input.entities import FormInput, UserAction
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
class AnnotationReplyAccount(BaseModel):

View File

@ -1,6 +1,6 @@
from graphon.model_runtime.entities.llm_entities import LLMUsage
from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
@ -57,37 +57,37 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
match system_configuration.current_quota_type:
case ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
case ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
case ProviderQuotaType.FREE:
with sessionmaker(bind=db.engine).begin() as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_limit > Provider.quota_used,
)
session.execute(stmt)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@ -12,7 +12,7 @@ from graphon.model_runtime.entities.message_entities import (
)
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -266,8 +266,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
event = message.event
if isinstance(event, QueueErrorEvent):
with sessionmaker(bind=db.engine).begin() as session:
with Session(db.engine) as session:
err = self.handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self.error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -287,9 +288,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
answer=output_moderation_answer
)
with sessionmaker(bind=db.engine).begin() as session:
with Session(db.engine) as session:
# Save message
self._save_message(session=session, trace_manager=trace_manager)
session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
@ -507,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = session.scalar(
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:

View File

@ -40,44 +40,41 @@ def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, Upl
size = 0
extension = ""
match message_file.transfer_method:
case FileTransferMethod.REMOTE_URL:
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
case FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
case FileTransferMethod.TOOL_FILE if message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
case FileTransferMethod.TOOL_FILE | FileTransferMethod.DATASOURCE_FILE:
pass
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""

View File

@ -6,7 +6,7 @@ from sqlalchemy import select, update
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities import RetrievalSourceMetadata
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
from extensions.ext_database import db

View File

@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = session.scalar(
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@ -1,3 +1,22 @@
from core.tools.entities.common_entities import I18nObject, I18nObjectDict
from pydantic import BaseModel, Field, model_validator
__all__ = ["I18nObject", "I18nObjectDict"]
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -9,7 +9,7 @@ from yarl import URL
from configs import dify_config
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities import OAuthSchema
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,

View File

@ -1,8 +1 @@
from core.entities.plugin_credential_type import PluginCredentialType
DEFAULT_PLUGIN_ID = "langgenius"
__all__ = [
"DEFAULT_PLUGIN_ID",
"PluginCredentialType",
]

View File

@ -1,9 +0,0 @@
import enum
class PluginCredentialType(enum.Enum):
MODEL = 0 # must be 0 for API contract compatibility
TOOL = 1 # must be 1 for API contract compatibility
def to_number(self):
return self.value

View File

@ -22,7 +22,6 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.entities import PluginCredentialType
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
CustomConfiguration,
@ -47,6 +46,7 @@ from models.provider import (
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@ -2,7 +2,7 @@
Credential utility functions for checking credential existence and policy compliance.
"""
from core.entities import PluginCredentialType
from services.enterprise.plugin_manager_service import PluginCredentialType
def is_credential_exists(credential_id: str, credential_type: "PluginCredentialType") -> bool:

View File

@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, TypedDict, cast
from typing import Protocol, cast
import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@ -49,17 +49,6 @@ class WorkflowServiceInterface(Protocol):
pass
class CodeGenerateResultDict(TypedDict):
code: str
language: str
error: str
class StructuredOutputResultDict(TypedDict):
output: str
error: str
class LLMGenerator:
@classmethod
def generate_conversation_name(
@ -304,7 +293,7 @@ class LLMGenerator:
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
) -> CodeGenerateResultDict:
):
if args.code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
@ -373,9 +362,7 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(
cls, tenant_id: str, args: RuleStructuredOutputPayload
) -> StructuredOutputResultDict:
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
@ -467,7 +454,7 @@ class LLMGenerator:
):
session = db.session()
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
app: App | None = session.query(App).where(App.id == flow_id).first()
if not app:
raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@ -6,7 +6,6 @@ import logging
import flask
from core.logging.context import get_request_id, get_trace_id
from core.logging.structured_formatter import IdentityDict
class TraceContextFilter(logging.Filter):
@ -61,7 +60,7 @@ class IdentityContextFilter(logging.Filter):
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> IdentityDict:
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
@ -78,7 +77,7 @@ class IdentityContextFilter(logging.Filter):
from models import Account
from models.model import EndUser
identity: IdentityDict = {}
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:

View File

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, NotRequired, TypedDict, cast
from typing import Any, cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@ -15,17 +15,6 @@ from services.app_generate_service import AppGenerateService
logger = logging.getLogger(__name__)
class ToolParameterSchemaDict(TypedDict):
type: str
properties: dict[str, Any]
required: list[str]
class ToolArgumentsDict(TypedDict):
query: NotRequired[str]
inputs: dict[str, Any]
def handle_mcp_request(
app: App,
request: mcp_types.ClientRequest,
@ -130,7 +119,7 @@ def handle_list_tools(
mcp_types.Tool(
name=app_name,
description=description,
inputSchema=cast(dict[str, Any], parameter_schema),
inputSchema=parameter_schema,
)
],
)
@ -165,7 +154,7 @@ def build_parameter_schema(
app_mode: str,
user_input_form: list[VariableEntity],
parameters_dict: dict[str, str],
) -> ToolParameterSchemaDict:
) -> dict[str, Any]:
"""Build parameter schema for the tool"""
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
@ -185,18 +174,17 @@ def build_parameter_schema(
}
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> ToolArgumentsDict:
def prepare_tool_arguments(app: App, arguments: dict[str, Any]) -> dict[str, Any]:
"""Prepare arguments based on app mode"""
match app.mode:
case AppMode.WORKFLOW:
return {"inputs": arguments}
case AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
case _:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
if app.mode == AppMode.WORKFLOW:
return {"inputs": arguments}
elif app.mode == AppMode.COMPLETION:
return {"query": "", "inputs": arguments}
else:
# Chat modes - create a copy to avoid modifying original dict
args_copy = arguments.copy()
query = args_copy.pop("query", "")
return {"query": query, "inputs": args_copy}
def extract_answer_from_response(app: App, response: Any) -> str:
@ -230,13 +218,17 @@ def process_streaming_response(response: RateLimitGenerator) -> str:
def process_mapping_response(app: App, response: Mapping) -> str:
"""Process mapping response based on app mode"""
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.COMPLETION | AppMode.CHAT | AppMode.AGENT_CHAT:
return response.get("answer", "")
case AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
case _:
raise ValueError("Invalid app mode: " + str(app.mode))
if app.mode in {
AppMode.ADVANCED_CHAT,
AppMode.COMPLETION,
AppMode.CHAT,
AppMode.AGENT_CHAT,
}:
return response.get("answer", "")
elif app.mode == AppMode.WORKFLOW:
return json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode: " + str(app.mode))
def convert_input_form_to_parameters(

View File

@ -17,7 +17,6 @@ from graphon.model_runtime.model_providers.__base.text_embedding_model import Te
from graphon.model_runtime.model_providers.__base.tts_model import TTSModel
from configs import dify_config
from core.entities import PluginCredentialType
from core.entities.embedding_type import EmbeddingInputType
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import ModelLoadBalancingConfiguration
@ -26,6 +25,7 @@ from core.plugin.impl.model_runtime_factory import create_plugin_provider_manage
from core.provider_manager import ProviderManager
from extensions.ext_redis import redis_client
from models.provider import ProviderType
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any
from graphon.entities import WorkflowNodeExecution
from graphon.enums import WorkflowNodeExecutionStatus
@ -56,22 +56,10 @@ def create_links_from_trace_id(trace_id: str | None) -> list[Link]:
return links
class RetrievalDocumentMetadataDict(TypedDict):
dataset_id: Any
doc_id: Any
document_id: Any
class RetrievalDocumentDict(TypedDict):
content: str
metadata: RetrievalDocumentMetadataDict
score: Any
def extract_retrieval_documents(documents: list[Document]) -> list[RetrievalDocumentDict]:
documents_data: list[RetrievalDocumentDict] = []
def extract_retrieval_documents(documents: list[Document]) -> list[dict[str, Any]]:
documents_data = []
for document in documents:
document_data: RetrievalDocumentDict = {
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
@ -95,7 +83,7 @@ def create_common_span_attributes(
framework: str = DEFAULT_FRAMEWORK_NAME,
inputs: str = "",
outputs: str = "",
) -> dict[str, str]:
) -> dict[str, Any]:
return {
GEN_AI_SESSION_ID: session_id,
GEN_AI_USER_ID: user_id,

View File

@ -56,10 +56,8 @@ class BaseTraceInstance(ABC):
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -241,10 +241,8 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -72,18 +72,17 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
conversation_id = conversation_id or ""
match app.mode:
case AppMode.ADVANCED_CHAT | AppMode.AGENT_CHAT | AppMode.CHAT:
if not query:
raise ValueError("missing query")
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT, AppMode.CHAT}:
if not query:
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
case AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
case AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
case _:
raise ValueError("unexpected app type")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
elif app.mode == AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode == AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
raise ValueError("unexpected app type")
@classmethod
def invoke_chat_app(
@ -99,61 +98,60 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
invoke chat app
"""
match app.mode:
case AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
case AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
case _:
if app.mode == AppMode.ADVANCED_CHAT:
workflow = app.workflow
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
elif app.mode == AppMode.CHAT:
return ChatAppGenerator().generate(
app_model=app,
user=user,
args={
"inputs": inputs,
"query": query,
"files": files,
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
)
else:
raise ValueError("unexpected app type")
@classmethod
def invoke_workflow_app(
cls,

View File

@ -1,5 +0,0 @@
from core.plugin.entities.oauth import OAuthSchema
__all__ = [
"OAuthSchema",
]

View File

@ -1,3 +1,5 @@
from collections.abc import Sequence
from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
@ -8,12 +10,12 @@ class OAuthSchema(BaseModel):
OAuth schema
"""
client_schema: list[ProviderConfig] = Field(
client_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="client schema like client_id, client_secret, etc.",
)
credentials_schema: list[ProviderConfig] = Field(
credentials_schema: Sequence[ProviderConfig] = Field(
default_factory=list,
description="credentials schema like access_token, refresh_token, etc.",
)

View File

@ -209,10 +209,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse,
params={
"plugin_unique_identifier": plugin_unique_identifier,
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
},
params={"plugin_unique_identifier": plugin_unique_identifier},
)
def fetch_plugin_installation_by_ids(

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -14,7 +15,6 @@ from graphon.model_runtime.entities.provider_entities import (
ProviderEntity,
)
from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@ -58,8 +58,6 @@ from services.feature_service import FeatureService
if TYPE_CHECKING:
from graphon.model_runtime.runtime import ModelRuntime
_credentials_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class ProviderManager:
"""
@ -877,8 +875,8 @@ class ProviderManager:
return {"openai_api_key": encrypted_config}
try:
credentials = _credentials_adapter.validate_json(encrypted_config)
except (ValueError, JSONDecodeError):
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
@ -961,37 +959,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
match provider_quota.quota_type:
case ProviderQuotaType.TRIAL if trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case ProviderQuotaType.PAID if paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
case _:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)
@ -1018,7 +1015,7 @@ class ProviderManager:
if not cached_provider_credentials:
provider_credentials: dict[str, Any] = {}
if provider_records and provider_records[0].encrypted_config:
provider_credentials = _credentials_adapter.validate_json(provider_records[0].encrypted_config)
provider_credentials = json.loads(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
@ -1165,10 +1162,8 @@ class ProviderManager:
if not cached_provider_model_credentials:
try:
provider_model_credentials = _credentials_adapter.validate_json(
load_balancing_model_config.encrypted_config
)
except (ValueError, JSONDecodeError):
provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
@ -1181,7 +1176,7 @@ class ProviderManager:
if variable in provider_model_credentials:
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable) or "",
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)

View File

@ -15,7 +15,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor,
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
from core.rag.entities import MetadataFilteringCondition
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.constant.query_type import QueryType
@ -182,9 +182,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
if metadata_filtering_conditions
else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,
@ -242,7 +240,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Dataset | None:
with Session(db.engine) as session:
return session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
@ -575,13 +573,15 @@ class RetrievalService:
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
summaries = session.scalars(
select(DocumentSegmentSummary).where(
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled.is_(True), # Only retrieve enabled summaries
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
).all()
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
@ -851,12 +851,12 @@ class RetrievalService:
def get_segment_attachment_info(
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
) -> SegmentAttachmentResult | None:
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == attachment_id).limit(1))
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
if upload_file:
attachment_binding = session.scalar(
select(SegmentAttachmentBinding)
attachment_binding = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id == upload_file.id)
.limit(1)
.first()
)
if attachment_binding:
attachment_info: AttachmentInfoDict = {
@ -875,12 +875,14 @@ class RetrievalService:
cls, attachment_ids: list[str], session: Session
) -> list[SegmentAttachmentInfoResult]:
attachment_infos: list[SegmentAttachmentInfoResult] = []
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(attachment_ids))).all()
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
if upload_files:
upload_file_ids = [upload_file.id for upload_file in upload_files]
attachment_bindings = session.scalars(
select(SegmentAttachmentBinding).where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
).all()
attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.attachment_id.in_(upload_file_ids))
.all()
)
attachment_binding_map = {binding.attachment_id: binding for binding in attachment_bindings}
if attachment_bindings:

View File

@ -37,12 +37,11 @@ class AnalyticdbVector(BaseVector):
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self.analyticdb_vector.create_collection_if_not_exists(dimension)
self.analyticdb_vector._create_collection_if_not_exists(dimension)
self.analyticdb_vector.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
self.analyticdb_vector.add_texts(documents, embeddings)
return []
def text_exists(self, id: str) -> bool:
return self.analyticdb_vector.text_exists(id)

View File

@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict
from typing import Any
from pydantic import BaseModel, model_validator
@ -13,13 +13,6 @@ from core.rag.models.document import Document
from extensions.ext_redis import redis_client
class AnalyticdbClientParamsDict(TypedDict):
access_key_id: str
access_key_secret: str
region_id: str
read_timeout: int
class AnalyticdbVectorOpenAPIConfig(BaseModel):
access_key_id: str
access_key_secret: str
@ -51,14 +44,13 @@ class AnalyticdbVectorOpenAPIConfig(BaseModel):
raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
return values
def to_analyticdb_client_params(self) -> AnalyticdbClientParamsDict:
result: AnalyticdbClientParamsDict = {
def to_analyticdb_client_params(self):
return {
"access_key_id": self.access_key_id,
"access_key_secret": self.access_key_secret,
"region_id": self.region_id,
"read_timeout": self.read_timeout,
}
return result
class AnalyticdbVectorOpenAPI:
@ -123,7 +115,7 @@ class AnalyticdbVectorOpenAPI:
else:
raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")
def create_collection_if_not_exists(self, embedding_dimension: int):
def _create_collection_if_not_exists(self, embedding_dimension: int):
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from Tea.exceptions import TeaException

View File

@ -1,6 +1,5 @@
import json
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any
@ -75,7 +74,7 @@ class AnalyticdbVectorBySql:
)
@contextmanager
def _get_cursor(self) -> Iterator[Any]:
def _get_cursor(self):
assert self.pool is not None, "Connection pool is not initialized"
conn = self.pool.getconn()
cur = conn.cursor()
@ -131,7 +130,7 @@ class AnalyticdbVectorBySql:
)
cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")
def create_collection_if_not_exists(self, embedding_dimension: int):
def _create_collection_if_not_exists(self, embedding_dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):

View File

@ -30,7 +30,7 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@ -85,12 +85,8 @@ class BaiduVector(BaseVector):
def get_type(self) -> str:
return VectorType.BAIDU
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self._create_table(len(embeddings[0]))

View File

@ -1,12 +1,12 @@
import json
from typing import Any, TypedDict
from typing import Any
import chromadb
from chromadb import QueryResult, Settings # pyright: ignore[reportPrivateImportUsage]
from chromadb import QueryResult, Settings
from pydantic import BaseModel
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@ -15,15 +15,6 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset
class ChromaParamsDict(TypedDict):
host: str
port: int
ssl: bool
tenant: str
database: str
settings: Settings
class ChromaConfig(BaseModel):
host: str
port: int
@ -32,13 +23,14 @@ class ChromaConfig(BaseModel):
auth_provider: str | None = None
auth_credentials: str | None = None
def to_chroma_params(self) -> ChromaParamsDict:
def to_chroma_params(self):
settings = Settings(
# auth
chroma_client_auth_provider=self.auth_provider,
chroma_client_auth_credentials=self.auth_credentials,
)
result: ChromaParamsDict = {
return {
"host": self.host,
"port": self.port,
"ssl": False,
@ -46,7 +38,6 @@ class ChromaConfig(BaseModel):
"database": self.database,
"settings": settings,
}
return result
class ChromaVector(BaseVector):
@ -106,15 +97,14 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
document_ids_filter = kwargs.get("document_ids_filter")
results: QueryResult
if document_ids_filter:
results = collection.query(
results: QueryResult = collection.query(
query_embeddings=query_vector,
n_results=kwargs.get("top_k", 4),
where={"document_id": {"$in": document_ids_filter}}, # type: ignore
)
else:
results = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4)) # type: ignore
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# Check if results contain data
@ -155,10 +145,7 @@ class ChromaVectorFactory(AbstractVectorFactory):
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
index_struct_dict: VectorIndexStructDict = {
"type": VectorType.CHROMA,
"vector_store": {"class_prefix": collection_name},
}
index_struct_dict = {"type": VectorType.CHROMA, "vector_store": {"class_prefix": collection_name}}
dataset.index_struct = json.dumps(index_struct_dict)
return ChromaVector(
@ -166,8 +153,8 @@ class ChromaVectorFactory(AbstractVectorFactory):
config=ChromaConfig(
host=dify_config.CHROMA_HOST or "",
port=dify_config.CHROMA_PORT,
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT, # pyright: ignore[reportPrivateImportUsage]
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE, # pyright: ignore[reportPrivateImportUsage]
tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
),

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, TypedDict
from typing import Any
from packaging import version
from pydantic import BaseModel, model_validator
@ -20,15 +20,6 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MilvusParamsDict(TypedDict):
uri: str
token: str | None
user: str | None
password: str | None
db_name: str
analyzer_params: str | None
class MilvusConfig(BaseModel):
"""
Configuration class for Milvus connection.
@ -59,11 +50,11 @@ class MilvusConfig(BaseModel):
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self) -> MilvusParamsDict:
def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
result: MilvusParamsDict = {
return {
"uri": self.uri,
"token": self.token,
"user": self.user,
@ -71,7 +62,6 @@ class MilvusConfig(BaseModel):
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
return result
class MilvusVector(BaseVector):
@ -362,7 +352,6 @@ class MilvusVector(BaseVector):
# Create Index params for the collection
index_params_obj = IndexParams()
assert index_params is not None
index_params_obj.add_index(field_name=Field.VECTOR, **index_params)
# Create Sparse Vector Index for the collection

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
import qdrant_client
from flask import current_app
@ -22,7 +22,7 @@ from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector, VectorIndexStructDict
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
@ -32,6 +32,7 @@ from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetCollectionBinding
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
@ -93,12 +94,8 @@ class QdrantVector(BaseVector):
def get_type(self) -> str:
return VectorType.QDRANT
def to_index_struct(self) -> VectorIndexStructDict:
result: VectorIndexStructDict = {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name},
}
return result
def to_index_struct(self):
return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
@ -179,7 +176,7 @@ class QdrantVector(BaseVector):
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, filtered_metadatas, uuids, 64, self._group_id
):
self._client.upsert(collection_name=self._collection_name, points=cast("common_types.Points", points))
self._client.upsert(collection_name=self._collection_name, points=points)
added_ids.extend(batch_ids)
return added_ids
@ -471,7 +468,7 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client._load() # pyright: ignore[reportPrivateUsage]
self._client._load()
@classmethod
def _document_from_scored_point(

View File

@ -26,7 +26,7 @@ from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
Base: Any = declarative_base()
Base = declarative_base() # type: Any
class RelytConfig(BaseModel):

Some files were not shown because too many files have changed in this diff Show More