Compare commits


1 Commit

Author SHA1 Message Date
fc9853f938 chore: split components 2025-12-29 16:02:52 +08:00
787 changed files with 4835 additions and 82058 deletions

View File

@ -3,7 +3,6 @@
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true,
"ralph-wiggum@claude-plugins-official": true
"pyright-lsp@claude-plugins-official": true
}
}

View File

@ -1,73 +0,0 @@
---
name: frontend-code-review
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
---
# Frontend Code Review
## Intent
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
1. **Pending-change review**: inspect staged/working-tree files slated for commit and flag checklist violations before submission.
2. **File-targeted review**: review the specific file(s) the user names and report the relevant checklist findings.
Stick to the checklist below for every applicable file and mode.
## Checklist
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), and [references/business-logic.md](references/business-logic.md) for the living checklist split by category; treat it as the canonical set of rules to follow.
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
## Review Process
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
2. For each checklist rule, note where the code deviates and capture a representative snippet.
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
## Required output
When invoked, the response must exactly follow one of the two templates:
### Template A (any findings)
```
# Code review
Found <N> urgent issues that need to be fixed:
## 1 <brief description of bug>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each urgent issue) ...
Found <M> suggestions for improvement:
## 1 <brief description of suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each suggestion) ...
```
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
If there are more than 10 issues, summarize the count as "10+ urgent issues" or "10+ suggestions" and output only the first 10.
Don't compress the blank lines between sections; keep them as-is for readability.
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
### Template B (no issues)
```
## Code review
No issues found.
```

View File

@ -1,15 +0,0 @@
# Rule Catalog — Business Logic
## Can't use workflowStore in Node components
IsUrgent: True
### Description
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
Node components are also used when creating a RAG Pipeline from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This issue](https://github.com/langgenius/dify/issues/29168) had exactly this cause.
### Suggested Fix
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.
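A minimal sketch of the safe pattern (the component and its contents are illustrative, not taken from the codebase):
```tsx
// Illustrative node component: read nodes via the reactflow hook so it still
// works when no workflowStore Provider exists (e.g., RAG Pipeline templates).
import { useNodes } from 'reactflow'

const ExampleNode = () => {
  const nodes = useNodes()
  return <div>{nodes.length} nodes</div>
}

export default ExampleNode
```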

View File

@ -1,44 +0,0 @@
# Rule Catalog — Code Quality
## Conditional class names use utility function
IsUrgent: True
Category: Code Quality
### Description
Ensure conditional class names are handled via the shared `cn` helper from `@/utils/classnames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
### Suggested Fix
```ts
import { cn } from '@/utils/classnames'
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
```
## Tailwind-first styling
IsUrgent: True
Category: Code Quality
### Description
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
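For illustration (a hedged sketch; the component and utility classes are hypothetical), prefer composing Tailwind utilities inline rather than creating a one-off CSS module:
```tsx
// Tailwind utilities cover the styling directly...
const Badge = () => (
  <span className="inline-flex items-center rounded-md px-2 py-1 text-xs text-gray-600">
    Beta
  </span>
)

// ...instead of adding a new badge.module.css just for this component.
export default Badge
```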
## Classname ordering for easy overrides
### Description
When writing components, always place the incoming `className` prop after the component's own class values so that downstream consumers can override or extend the styling. This keeps your component's defaults but still lets external callers change or remove specific styles.
Example:
```tsx
import { cn } from '@/utils/classnames'
const Button = ({ className }) => {
return <div className={cn('bg-primary-600', className)}></div>
}
```

View File

@ -1,45 +0,0 @@
# Rule Catalog — Performance
## React Flow data usage
IsUrgent: True
Category: Performance
### Description
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
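A minimal sketch of the intended split (hypothetical component; assumes the `reactflow` v11 hooks):
```tsx
import { useCallback } from 'react'
import { useNodes, useStoreApi } from 'reactflow'

const SelectionSummary = () => {
  // UI consumption: subscribe through the hook so renders track node changes.
  const nodes = useNodes()
  const store = useStoreApi()

  // Callback consumption: read the latest state on demand via the store API,
  // without adding extra render subscriptions.
  const logSelected = useCallback(() => {
    const { getNodes } = store.getState()
    console.log(getNodes().filter(node => node.selected))
  }, [store])

  return <button onClick={logSelected}>{nodes.length} nodes</button>
}

export default SelectionSummary
```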
## Complex prop memoization
IsUrgent: True
Category: Performance
### Description
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
Wrong:
```tsx
<HeavyComp
config={{
provider: ...,
detail: ...
}}
/>
```
Right:
```tsx
const config = useMemo(() => ({
provider: ...,
detail: ...
}), [provider, detail]);
<HeavyComp
config={config}
/>
```

View File

@ -28,14 +28,17 @@ import userEvent from '@testing-library/user-event'
// i18n (automatically mocked)
// WHY: Global mock in web/vitest.setup.ts is auto-loaded by Vitest setup
// The global mock provides: useTranslation, Trans, useMixedTranslation, useGetLanguage
// No explicit mock needed for most tests
//
// No explicit mock needed - it returns translation keys as-is
// Override only if custom translations are required:
// import { createReactI18nextMock } from '@/test/i18n-mock'
// vi.mock('react-i18next', () => createReactI18nextMock({
// 'my.custom.key': 'Custom Translation',
// 'button.save': 'Save',
// vi.mock('react-i18next', () => ({
// useTranslation: () => ({
// t: (key: string) => {
// const customTranslations: Record<string, string> = {
// 'my.custom.key': 'Custom Translation',
// }
// return customTranslations[key] || key
// },
// }),
// }))
// Router (if component uses useRouter, usePathname, useSearchParams)

View File

@ -52,29 +52,23 @@ Modules are not mocked automatically. Use `vi.mock` in test files, or add global
### 1. i18n (Auto-loaded via Global Mock)
A global mock is defined in `web/vitest.setup.ts` and is auto-loaded by Vitest setup.
**No explicit mock needed** for most tests - it returns translation keys as-is.
The global mock provides:
- `useTranslation` - returns translation keys with namespace prefix
- `Trans` component - renders i18nKey and components
- `useMixedTranslation` (from `@/app/components/plugins/marketplace/hooks`)
- `useGetLanguage` (from `@/context/i18n`) - returns `'en-US'`
**Default behavior**: Most tests should use the global mock (no local override needed).
**For custom translations**: Use the helper function from `@/test/i18n-mock`:
For tests requiring custom translations, override the mock:
```typescript
import { createReactI18nextMock } from '@/test/i18n-mock'
vi.mock('react-i18next', () => createReactI18nextMock({
'my.custom.key': 'Custom translation',
'button.save': 'Save',
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
'my.custom.key': 'Custom translation',
}
return translations[key] || key
},
}),
}))
```
**Avoid**: Manually defining `useTranslation` mocks that just return the key - the global mock already does this.
### 2. Next.js Router
```typescript

View File

@ -110,16 +110,6 @@ jobs:
working-directory: ./web
run: pnpm run type-check:tsgo
- name: Web dead code check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run knip
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run build
superlinter:
name: SuperLinter
runs-on: ubuntu-latest

View File

@ -5,7 +5,6 @@ on:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
@ -19,8 +18,7 @@ jobs:
run:
working-directory: web
steps:
# Keep using the older checkout action version because of https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
- uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@ -28,28 +26,21 @@ jobs:
- name: Check for file changes in i18n/en-US
id: check_files
run: |
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Install pnpm
@ -74,7 +65,7 @@ jobs:
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
run: pnpm run auto-gen-i18n ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'

.gitignore
View File

@ -235,4 +235,3 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md

View File

@ -101,15 +101,6 @@ S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Workflow run and Conversation archive storage (S3-compatible)
ARCHIVE_STORAGE_ENABLED=false
ARCHIVE_STORAGE_ENDPOINT=
ARCHIVE_STORAGE_ARCHIVE_BUCKET=
ARCHIVE_STORAGE_EXPORT_BUCKET=
ARCHIVE_STORAGE_ACCESS_KEY=
ARCHIVE_STORAGE_SECRET_KEY=
ARCHIVE_STORAGE_REGION=auto
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
@ -502,8 +493,6 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s

View File

@ -1,8 +1,4 @@
exclude = [
"migrations/*",
".git",
".git/**",
]
exclude = ["migrations/*"]
line-length = 120
[format]

View File

@ -2,11 +2,9 @@ import logging
import time
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@ -27,35 +25,28 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
# Initialize logging context for this request
init_request_context()
# add an unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
@dify_app.after_request
def add_trace_headers(response):
def add_trace_id_header(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace headers to response", exc_info=True)
logger.warning("Failed to add trace ID to response header", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_headers
_ = add_trace_id_header
return dify_app

View File

@ -1,11 +1,9 @@
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
ArchiveStorageConfig,
NotionConfig,
SentryConfig,
):

View File

@ -1,43 +0,0 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class ArchiveStorageConfig(BaseSettings):
"""
Configuration settings for workflow run logs archiving storage.
"""
ARCHIVE_STORAGE_ENABLED: bool = Field(
description="Enable workflow run logs archiving to S3-compatible storage",
default=False,
)
ARCHIVE_STORAGE_ENDPOINT: str | None = Field(
description="URL of the S3-compatible storage endpoint (e.g., 'https://storage.example.com')",
default=None,
)
ARCHIVE_STORAGE_ARCHIVE_BUCKET: str | None = Field(
description="Name of the bucket to store archived workflow logs",
default=None,
)
ARCHIVE_STORAGE_EXPORT_BUCKET: str | None = Field(
description="Name of the bucket to store exported workflow runs",
default=None,
)
ARCHIVE_STORAGE_ACCESS_KEY: str | None = Field(
description="Access key ID for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_SECRET_KEY: str | None = Field(
description="Secret access key for authenticating with storage",
default=None,
)
ARCHIVE_STORAGE_REGION: str = Field(
description="Region for storage (use 'auto' if the provider supports it)",
default="auto",
)

View File

@ -587,11 +587,6 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
default="text",
)
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,

View File

@ -1,59 +1,62 @@
from __future__ import annotations
from flask_restx import Api, Namespace, fields
from typing import Any, TypeAlias
from libs.helper import AppIconUrlField
from pydantic import BaseModel, ConfigDict, computed_field
from core.file import helpers as file_helpers
from models.model import IconType
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
JSONObject: TypeAlias = dict[str, Any]
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
class SystemParameters(BaseModel):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
def build_system_parameters_model(api_or_ns: Api | Namespace):
"""Build the system parameters model for the API or Namespace."""
return api_or_ns.model("SystemParameters", parameters__system_parameters)
class Parameters(BaseModel):
opening_statement: str | None = None
suggested_questions: list[str]
suggested_questions_after_answer: JSONObject
speech_to_text: JSONObject
text_to_speech: JSONObject
retriever_resource: JSONObject
annotation_reply: JSONObject
more_like_this: JSONObject
user_input_form: list[JSONObject]
sensitive_word_avoidance: JSONObject
file_upload: JSONObject
system_parameters: SystemParameters
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}
class Site(BaseModel):
model_config = ConfigDict(from_attributes=True)
def build_parameters_model(api_or_ns: Api | Namespace):
"""Build the parameters model for the API or Namespace."""
copied_fields = parameters_fields.copy()
copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns))
return api_or_ns.model("Parameters", copied_fields)
title: str
chat_color_theme: str | None = None
chat_color_theme_inverted: bool
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
description: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
default_language: str
show_workflow_steps: bool
use_icon_as_answer_icon: bool
@computed_field(return_type=str | None) # type: ignore
@property
def icon_url(self) -> str | None:
if self.icon and self.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(self.icon)
return None
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}
def build_site_model(api_or_ns: Api | Namespace):
"""Build the site model for the API or Namespace."""
return api_or_ns.model("Site", site_fields)

View File

@ -1,4 +1,3 @@
import re
import uuid
from typing import Literal
@ -74,48 +73,6 @@ class AppListQuery(BaseModel):
raise ValueError("Invalid UUID format in tag_ids.") from exc
# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
r"<script[^>]*>.*?</script>", # Script tags
r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)", # Iframe tags (including self-closing)
r"javascript:", # JavaScript protocol
r"<svg[^>]*?\s+onload\s*=[^>]*>", # SVG with onload handler (attribute-aware, flexible whitespace)
r"<.*?on\s*\w+\s*=", # Event handlers like onclick, onerror, etc.
r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)", # Object tags (opening tag)
r"<embed[^>]*>", # Embed tags (self-closing)
r"<link[^>]*>", # Link tags with javascript
]
def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
"""
Validate that a string value doesn't contain potential XSS payloads.
Args:
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The original value if safe
Raises:
ValueError: If the value contains XSS patterns
"""
if value is None:
return None
value_lower = value.lower()
for pattern in _XSS_PATTERNS:
if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
raise ValueError(
f"{field_name} contains invalid characters or patterns. "
"HTML tags, JavaScript, and other potentially dangerous content are not allowed."
)
return value
class CreateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
@ -124,11 +81,6 @@ class CreateAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class UpdateAppPayload(BaseModel):
name: str = Field(..., min_length=1, description="App name")
@ -139,11 +91,6 @@ class UpdateAppPayload(BaseModel):
use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
max_active_requests: int | None = Field(default=None, description="Maximum active requests")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class CopyAppPayload(BaseModel):
name: str | None = Field(default=None, description="Name for the copied app")
@ -152,11 +99,6 @@ class CopyAppPayload(BaseModel):
icon: str | None = Field(default=None, description="Icon")
icon_background: str | None = Field(default=None, description="Icon background color")
@field_validator("name", "description", mode="before")
@classmethod
def validate_xss_safe(cls, value: str | None, info) -> str | None:
return _validate_xss_safe(value, info.field_name)
class AppExportQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")

View File

@ -13,6 +13,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
@ -176,12 +177,6 @@ annotation_hit_history_model = console_ns.model(
},
)
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",

View File

@ -124,7 +124,7 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account, oauth_new_user = _generate_account(provider, user_info)
account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@ -159,10 +159,7 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
base_url = dify_config.CONSOLE_WEB_URL
query_char = "&" if "?" in base_url else "?"
target_url = f"{base_url}{query_char}oauth_new_user={str(oauth_new_user).lower()}"
response = redirect(target_url)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
@ -180,10 +177,9 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
return account
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
oauth_new_user = False
if account:
tenants = TenantService.get_join_tenants(account)
@ -197,7 +193,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
raise AccountRegisterError(
@ -225,4 +220,4 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
return account, oauth_new_user
return account

View File

@ -3,12 +3,10 @@ import uuid
from flask import request
from flask_restx import Resource, marshal
from pydantic import BaseModel, Field
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@ -145,29 +143,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
# PostgreSQL: Use jsonb_array_elements_text to properly handle Unicode/Chinese text
keywords_condition = func.array_to_string(
func.array(
select(func.jsonb_array_elements_text(cast(DocumentSegment.keywords, JSONB)))
.correlate(DocumentSegment)
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
keywords_condition,
)
)
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args.enabled.lower() != "all":
if args.enabled.lower() == "true":

View File

@ -1,7 +1,8 @@
from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -10,11 +11,7 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
@ -52,6 +49,7 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
@ -75,7 +73,7 @@ class ConversationListApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
@ -84,13 +82,6 @@ class ConversationListApi(InstalledAppResource):
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -114,7 +105,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return {"result": "success"}, 204
@console_ns.route(
@ -122,6 +113,7 @@ class ConversationApi(InstalledAppResource):
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
@ -136,14 +128,9 @@ class ConversationRenameApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
conversation = ConversationService.rename(
return ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -168,7 +155,7 @@ class ConversationPinApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@console_ns.route(
@ -187,4 +174,4 @@ class ConversationUnPinApi(InstalledAppResource):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}

View File

@ -2,7 +2,8 @@ import logging
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -22,8 +23,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from fields.message_fields import message_infinite_scroll_pagination_fields
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
@ -66,6 +66,7 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@ -77,20 +78,13 @@ class MessageListApi(InstalledAppResource):
args = MessageListQuery.model_validate(request.args.to_dict())
try:
pagination = MessageService.pagination_by_first_id(
return MessageService.pagination_by_first_id(
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -122,7 +116,7 @@ class MessageFeedbackApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@console_ns.route(
@ -207,4 +201,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logger.exception("internal server error.")
raise InternalServerError()
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
return {"data": questions}

View File

@ -1,3 +1,5 @@
from flask_restx import marshal_with
from controllers.common import fields
from controllers.console import console_ns
from controllers.console.app.error import AppUnavailableError
@ -11,6 +13,7 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
@marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
@ -34,8 +37,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta")

View File

@ -1,14 +1,14 @@
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -26,8 +26,28 @@ class SavedMessageCreatePayload(BaseModel):
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@ -37,19 +57,12 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
pagination = SavedMessageService.pagination_by_last_id(
return SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
@ -65,7 +78,7 @@ class SavedMessageListApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@console_ns.route(
@ -83,4 +96,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return {"result": "success"}, 204

View File

@ -20,6 +20,7 @@ from controllers.console.wraps import (
)
from core.db.session_factory import session_factory
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@ -986,6 +987,9 @@ class ToolProviderMCPApi(Resource):
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
# Final cache invalidation to ensure list views are up to date
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)
@console_ns.expect(parser_mcp_put)
@ -1032,6 +1036,9 @@ class ToolProviderMCPApi(Resource):
validation_result=validation_result,
)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@console_ns.expect(parser_mcp_delete)
@ -1046,6 +1053,9 @@ class ToolProviderMCPApi(Resource):
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=args["provider_id"])
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@ -1096,6 +1106,8 @@ class ToolMCPAuthApi(Resource):
credentials=provider_entity.credentials,
authed=True,
)
# Invalidate cache after updating credentials
ToolProviderListCache.invalidate_cache(tenant_id)
return {"result": "success"}
except MCPAuthError as e:
try:
@ -1109,16 +1121,22 @@ class ToolMCPAuthApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
# Invalidate cache after auth actions may have updated provider state
ToolProviderListCache.invalidate_cache(tenant_id)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
# Invalidate cache after clearing credentials
ToolProviderListCache.invalidate_cache(tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e

View File

@ -1,7 +1,7 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Api, Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field
@ -92,7 +92,7 @@ annotation_list_fields = {
}
def build_annotation_list_model(api_or_ns: Namespace):
def build_annotation_list_model(api_or_ns: Api | Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))

View File

@ -1,6 +1,6 @@
from flask_restx import Resource
from controllers.common.fields import Parameters
from controllers.common.fields import build_parameters_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@ -23,6 +23,7 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_parameters_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app parameters.
@ -44,8 +45,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return Parameters.model_validate(parameters).model_dump(mode="json")
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@service_api_ns.route("/meta")

View File

@ -3,7 +3,8 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@ -15,9 +16,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
ConversationDelete,
ConversationInfiniteScrollPagination,
SimpleConversation,
build_conversation_delete_model,
build_conversation_infinite_scroll_pagination_model,
build_simple_conversation_model,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
@ -104,6 +105,7 @@ class ConversationApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List all conversations for the current user.
@ -118,7 +120,7 @@ class ConversationApi(Resource):
try:
with Session(db.engine) as session:
pagination = ConversationService.pagination_by_last_id(
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@ -127,13 +129,6 @@ class ConversationApi(Resource):
invoke_from=InvokeFrom.SERVICE_API,
sort_by=query_args.sort_by,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@ -151,6 +146,7 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
@ -163,7 +159,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ConversationDelete(result="success").model_dump(mode="json"), 204
return {"result": "success"}, 204
@service_api_ns.route("/conversations/<uuid:c_id>/name")
@ -180,6 +176,7 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
@ -191,14 +188,7 @@ class ConversationRenameApi(Resource):
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
conversation = ConversationService.rename(
app_model, conversation_id, end_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,10 +1,11 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@ -13,8 +14,10 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@ -45,6 +48,49 @@ class FeedbackListQuery(BaseModel):
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace."""
# First build the nested models
feedback_model = build_feedback_model(api_or_ns)
agent_thought_model = build_agent_thought_model(api_or_ns)
message_file_model = build_message_file_model(api_or_ns)
# Then build the message fields with nested models
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_model)),
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"status": fields.String,
"error": fields.String,
}
return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first
message_model = build_message_model(api_or_ns)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_model)),
}
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@ -58,6 +104,7 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List messages in a conversation.
@ -72,16 +119,9 @@ class MessageListApi(Resource):
first_id = str(query_args.first_id) if query_args.first_id else None
try:
pagination = MessageService.pagination_by_first_id(
return MessageService.pagination_by_first_id(
app_model, end_user, conversation_id, first_id, query_args.limit
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -122,7 +162,7 @@ class MessageFeedbackApi(Resource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@service_api_ns.route("/app/feedbacks")

View File

@ -1,7 +1,7 @@
from flask_restx import Resource
from werkzeug.exceptions import Forbidden
from controllers.common.fields import Site as SiteResponse
from controllers.common.fields import build_site_model
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
@ -23,6 +23,7 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
@service_api_ns.marshal_with(build_site_model(service_api_ns))
def get(self, app_model: App):
"""Retrieve app site info.
@ -37,4 +38,4 @@ class AppSiteApi(Resource):
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return SiteResponse.model_validate(site).model_dump(mode="json")
return site

View File

@ -3,7 +3,7 @@ from typing import Any, Literal
from dateutil.parser import isoparse
from flask import request
from flask_restx import Namespace, Resource, fields
from flask_restx import Api, Namespace, Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
@ -78,7 +78,7 @@ workflow_run_fields = {
}
def build_workflow_run_model(api_or_ns: Namespace):
def build_workflow_run_model(api_or_ns: Api | Namespace):
"""Build the workflow run model for the API or Namespace."""
return api_or_ns.model("WorkflowRun", workflow_run_fields)

View File

@ -1,7 +1,7 @@
import logging
from flask import request
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, ConfigDict, Field
from werkzeug.exceptions import Unauthorized
@ -50,6 +50,7 @@ class AppParameterApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
@ -68,8 +69,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return fields.Parameters.model_validate(parameters).model_dump(mode="json")
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
@web_ns.route("/meta")

View File

@ -1,6 +1,5 @@
from flask_restx import reqparse
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@ -9,11 +8,7 @@ from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
@ -59,6 +54,7 @@ class ConversationListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -86,7 +82,7 @@ class ConversationListApi(WebApiResource):
try:
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@ -96,19 +92,16 @@ class ConversationListApi(WebApiResource):
pinned=pinned,
sort_by=args["sort_by"],
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>")
class ConversationApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -122,6 +115,7 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -132,7 +126,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return {"result": "success"}, 204
@web_ns.route("/conversations/<uuid:c_id>/name")
@ -161,6 +155,7 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -176,20 +171,17 @@ class ConversationRenameApi(WebApiResource):
args = parser.parse_args()
try:
conversation = ConversationService.rename(
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>/pin")
class ConversationPinApi(WebApiResource):
pin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -203,6 +195,7 @@ class ConversationPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -215,11 +208,15 @@ class ConversationPinApi(WebApiResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@web_ns.route("/conversations/<uuid:c_id>/unpin")
class ConversationUnPinApi(WebApiResource):
unpin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@ -233,6 +230,7 @@ class ConversationUnPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -241,4 +239,4 @@ class ConversationUnPinApi(WebApiResource):
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}

View File

@ -2,7 +2,8 @@ import logging
from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@ -21,10 +22,11 @@ from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from libs import helper
from libs.helper import uuid_value
from libs.helper import TimestampField, uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@ -68,6 +70,29 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
@ -96,6 +121,7 @@ class MessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -105,16 +131,9 @@ class MessageListApi(WebApiResource):
query = MessageListQuery.model_validate(raw_args)
try:
pagination = MessageService.pagination_by_first_id(
return MessageService.pagination_by_first_id(
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
adapter = TypeAdapter(WebMessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return WebMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@ -123,6 +142,10 @@ class MessageListApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(WebApiResource):
feedback_response_fields = {
"result": fields.String,
}
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@ -147,6 +170,7 @@ class MessageFeedbackApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -163,7 +187,7 @@ class MessageFeedbackApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
@ -223,6 +247,10 @@ class MessageMoreLikeThisApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
suggested_questions_response_fields = {
"data": fields.List(fields.String),
}
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@ -236,6 +264,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -248,6 +277,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
# so we can directly return it
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
@ -266,4 +296,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
logger.exception("internal server error.")
raise InternalServerError()
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")
return {"data": questions}

View File

@ -1,20 +1,40 @@
from flask_restx import reqparse
from flask_restx import fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
post_response_fields = {
"result": fields.String,
}
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
@ -38,6 +58,7 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -49,14 +70,7 @@ class SavedMessageListApi(WebApiResource):
)
args = parser.parse_args()
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@ -75,6 +89,7 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -87,11 +102,15 @@ class SavedMessageListApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return ResultResponse(result="success").model_dump(mode="json")
return {"result": "success"}
@web_ns.route("/saved-messages/<uuid:message_id>")
class SavedMessageApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@ -105,6 +124,7 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@ -113,4 +133,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return {"result": "success"}, 204

View File

@ -22,7 +22,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -166,11 +165,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:

View File

@ -25,7 +25,6 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
@ -223,10 +222,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
final_answer += response + "\n"
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and tool_calls:
raise AgentMaxIterationError(app_config.agent.max_iteration)
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:

View File

@ -30,6 +30,7 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@ -43,6 +44,7 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@ -92,6 +94,7 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@ -88,41 +88,7 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Silently ignore errors to avoid breaking requests
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@ -140,21 +106,18 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Preserve the user-provided Host header
# httpx may override the Host header when using a proxy
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host

View File

@ -0,0 +1,58 @@
import json
import logging
from typing import Any, cast
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from extensions.ext_redis import redis_client, redis_fallback
logger = logging.getLogger(__name__)
class ToolProviderListCache:
"""Cache for tool provider lists"""
CACHE_TTL = 300 # 5 minutes
@staticmethod
def _generate_cache_key(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> str:
"""Generate cache key for tool providers list"""
type_filter = typ or "all"
return f"tool_providers:tenant_id:{tenant_id}:type:{type_filter}"
@staticmethod
@redis_fallback(default_return=None)
def get_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral = None) -> list[dict[str, Any]] | None:
"""Get cached tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
cached_data = redis_client.get(cache_key)
if cached_data:
try:
return json.loads(cached_data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
logger.warning("Failed to decode cached tool providers data")
return None
return None
@staticmethod
@redis_fallback()
def set_cached_providers(tenant_id: str, typ: ToolProviderTypeApiLiteral, providers: list[dict[str, Any]]):
"""Cache tool providers"""
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.setex(cache_key, ToolProviderListCache.CACHE_TTL, json.dumps(providers))
@staticmethod
@redis_fallback()
def invalidate_cache(tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""Invalidate cache for tool providers"""
if typ:
# Invalidate specific type cache
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, typ)
redis_client.delete(cache_key)
else:
# Invalidate all caches for this tenant
keys = ["builtin", "model", "api", "workflow", "mcp"]
pipeline = redis_client.pipeline()
for key in keys:
cache_key = ToolProviderListCache._generate_cache_key(tenant_id, cast(ToolProviderTypeApiLiteral, key))
pipeline.delete(cache_key)
pipeline.execute()
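For context, a minimal sketch of how a tenant-scoped provider-list cache like the one added above can be exercised: miss, fill, hit, invalidate. The key format and 300-second TTL follow the diff; `FakeRedis` is an in-memory stand-in for `redis_client`, used here only so the snippet runs on its own.

```python
import json


class FakeRedis:
    """In-memory stand-in for the Redis client (TTL is accepted but not enforced)."""

    def __init__(self) -> None:
        self._store: dict[str, bytes] = {}

    def get(self, key: str) -> bytes | None:
        return self._store.get(key)

    def setex(self, key: str, ttl: int, value: str) -> None:
        self._store[key] = value.encode("utf-8")

    def delete(self, key: str) -> None:
        self._store.pop(key, None)


redis_client = FakeRedis()
CACHE_TTL = 300  # seconds, as in the diff


def cache_key(tenant_id: str, typ: str | None) -> str:
    return f"tool_providers:tenant_id:{tenant_id}:type:{typ or 'all'}"


key = cache_key("tenant-1", "builtin")
if redis_client.get(key) is None:                       # cache miss
    providers = [{"name": "time", "type": "builtin"}]   # expensive lookup happens elsewhere
    redis_client.setex(key, CACHE_TTL, json.dumps(providers))

cached = redis_client.get(key)                          # cache hit
print(json.loads(cached.decode("utf-8")) if cached else None)

redis_client.delete(key)                                # invalidate after a provider changes
```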

View File

@ -103,60 +103,3 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"
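The helper removed above builds a W3C `traceparent` value in the `{version}-{trace_id}-{span_id}-{flags}` layout. A self-contained sketch of just that formatting step, using `uuid` instead of the OTEL or logging contexts:

```python
import uuid


def make_traceparent(trace_id: str | None = None) -> str:
    """Build a W3C traceparent header: version-traceid-spanid-flags."""
    trace_id = trace_id or uuid.uuid4().hex  # 32 hex chars
    span_id = uuid.uuid4().hex[:16]          # 16 hex chars
    return f"00-{trace_id}-{span_id}-01"


header = make_traceparent()
version, trace_id, span_id, flags = header.split("-")
assert len(trace_id) == 32 and len(span_id) == 16
print(header)  # e.g. 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
```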

View File

@ -1,20 +0,0 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@ -1,35 +0,0 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
"""Get current request ID (10 hex chars)."""
return _request_id.get()
def get_trace_id() -> str:
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
return _trace_id.get()
def init_request_context() -> None:
"""Initialize request context. Call at start of each request."""
req_id = uuid.uuid4().hex[:10]
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
_request_id.set(req_id)
_trace_id.set(trace_id)
def clear_request_context() -> None:
"""Clear request context. Call at end of request (optional)."""
_request_id.set("")
_trace_id.set("")
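The removed module relies on `contextvars` for request-scoped values that stay isolated across threads and async tasks. A trimmed, runnable sketch of the same pattern (only the request-ID half, names mirroring the diff):

```python
import uuid
from contextvars import ContextVar

_request_id: ContextVar[str] = ContextVar("log_request_id", default="")


def init_request_context() -> None:
    """Call at the start of each request."""
    _request_id.set(uuid.uuid4().hex[:10])


def get_request_id() -> str:
    return _request_id.get()


init_request_context()
print(get_request_id())  # e.g. 'a1b2c3d4e5', scoped to the current context
```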

View File

@ -1,94 +0,0 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}
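Filters like the two removed above work by mutating each `LogRecord` before it is formatted. A generic sketch of wiring such a filter into a handler (not the exact Dify filters, just the mechanism):

```python
import logging


class StaticContextFilter(logging.Filter):
    """Ensure trace fields exist so format strings referencing them never fail."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.trace_id = getattr(record, "trace_id", "")
        record.req_id = getattr(record, "req_id", "")
        return True  # keep the record


handler = logging.StreamHandler()
handler.addFilter(StaticContextFilter())
handler.setFormatter(logging.Formatter("[%(req_id)s] %(levelname)s %(message)s"))

logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
logger.info("hello")  # prints: [] INFO hello
```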

View File

@ -1,107 +0,0 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:
return orjson.dumps(log_dict).decode("utf-8")
except TypeError:
# Fallback: convert non-serializable objects to string
import json
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "")
span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""
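A much-reduced sketch of the same idea — a formatter that serializes each record to one JSON line. Only a few of the schema fields from the docstring above are included, and the output shown in the comment is illustrative:

```python
import json
import logging
from datetime import datetime, timezone


class TinyJSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {
                "ts": datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
                "severity": record.levelname,
                "caller": f"{record.filename}:{record.lineno}",
                "message": record.getMessage(),
            }
        )


handler = logging.StreamHandler()
handler.setFormatter(TinyJSONFormatter())
log = logging.getLogger("json-demo")
log.addHandler(handler)
log.propagate = False
log.warning("disk almost full")
# {"ts": "2025-01-01T00:00:00.000Z", "severity": "WARNING", "caller": "example.py:21", "message": "disk almost full"}
```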

View File

@ -100,6 +100,7 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@ -122,6 +123,7 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@ -155,6 +157,7 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)

View File

@ -285,7 +285,7 @@ class ModelProviderFactory:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_small_dark)
:param icon_type: icon type (icon_small or icon_large)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@ -309,7 +309,13 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

View File

@ -103,9 +103,6 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@ -117,31 +114,6 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
if not dify_config.ENABLE_OTEL:
return
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@ -331,6 +331,7 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)

View File

@ -27,44 +27,26 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL but keep Markdown image URLs and link URLs
# Replace the ENTIRE markdown link/image with a single placeholder to protect
# the link text (which might also be a URL) from being removed
markdown_link_pattern = r"\[([^\]]*)\]\((https?://[^)]+)\)"
markdown_image_pattern = r"!\[.*?\]\((https?://[^)]+)\)"
placeholders: list[tuple[str, str, str]] = [] # (type, text, url)
# Remove URL but keep Markdown image URLs
# First, temporarily replace Markdown image URLs with a placeholder
markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
def replace_markdown_with_placeholder(match, placeholders=placeholders):
link_type = "link"
link_text = match.group(1)
url = match.group(2)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, link_text, url))
return placeholder
def replace_image_with_placeholder(match, placeholders=placeholders):
link_type = "image"
def replace_with_placeholder(match, placeholders=placeholders):
url = match.group(1)
placeholder = f"__MARKDOWN_PLACEHOLDER_{len(placeholders)}__"
placeholders.append((link_type, "image", url))
return placeholder
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
# Protect markdown links first
text = re.sub(markdown_link_pattern, replace_markdown_with_placeholder, text)
# Then protect markdown images
text = re.sub(markdown_image_pattern, replace_image_with_placeholder, text)
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://\S+"
url_pattern = r"https?://[^\s)]+"
text = re.sub(url_pattern, "", text)
# Restore the Markdown links and images
for i, (link_type, text_or_alt, url) in enumerate(placeholders):
placeholder = f"__MARKDOWN_PLACEHOLDER_{i}__"
if link_type == "link":
text = text.replace(placeholder, f"[{text_or_alt}]({url})")
else: # image
text = text.replace(placeholder, f"![{text_or_alt}]({url})")
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
return text
def filter_string(self, text):
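A quick demonstration of the simpler behaviour this hunk reverts to: markdown image URLs are shielded behind a placeholder, every other URL is stripped, and the placeholders are restored afterwards. The two regexes are copied from the hunk; the sample text is made up:

```python
import re

text = "See ![logo](https://example.com/logo.png) and https://example.com/page for details."

placeholders: list[str] = []


def protect_image(match: re.Match) -> str:
    placeholders.append(match.group(1))
    return f"![image](__MARKDOWN_IMAGE_URL_{len(placeholders) - 1}__)"


text = re.sub(r"!\[.*?\]\((https?://[^\s)]+)\)", protect_image, text)  # protect image URLs
text = re.sub(r"https?://[^\s)]+", "", text)                           # strip remaining URLs
for i, url in enumerate(placeholders):                                 # restore images
    text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)

print(text)  # See ![image](https://example.com/logo.png) and  for details.
```

Note that, unlike the version being removed, plain markdown links such as `[text](https://…)` are not protected, so their URLs are stripped as well.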

View File

@ -1,5 +1,4 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@ -37,8 +36,6 @@ default_retrieval_model = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@ -109,12 +106,7 @@ class RetrievalService:
)
)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break
concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))
@ -218,7 +210,6 @@ class RetrievalService:
)
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -312,7 +303,6 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@classmethod
@ -361,7 +351,6 @@ class RetrievalService:
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))
@staticmethod
@ -674,14 +663,7 @@ class RetrievalService:
document_ids_filter=document_ids_filter,
)
)
# Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
if exceptions:
raise ValueError(";\n".join(exceptions))

View File

@ -112,7 +112,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@ -148,7 +148,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:

View File

@ -1,57 +1,25 @@
"""Abstract interface for document loader implementations."""
import contextlib
import io
import logging
import uuid
from collections.abc import Iterator
import pypdfium2
import pypdfium2.raw as pdfium_c
from configs import dify_config
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from models.enums import CreatorUserRole
from models.model import UploadFile
logger = logging.getLogger(__name__)
class PdfExtractor(BaseExtractor):
"""
PdfExtractor is used to extract text and images from PDF files.
"""Load pdf files.
Args:
file_path: Path to the PDF file.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
file_cache_key: Optional cache key for the extracted text.
file_path: Path to the file to load.
"""
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
IMAGE_FORMATS = [
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
(b"GIF8", "gif", "image/gif"),
(b"BM", "bmp", "image/bmp"),
(b"II*\x00", "tiff", "image/tiff"),
(b"MM\x00*", "tiff", "image/tiff"),
(b"II+\x00", "tiff", "image/tiff"),
(b"MM\x00+", "tiff", "image/tiff"),
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
def __init__(self, file_path: str, file_cache_key: str | None = None):
"""Initialize with file path."""
self._file_path = file_path
self._tenant_id = tenant_id
self._user_id = user_id
self._file_cache_key = file_cache_key
def extract(self) -> list[Document]:
@ -82,6 +50,7 @@ class PdfExtractor(BaseExtractor):
def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
import pypdfium2 # type: ignore
with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
@ -90,87 +59,8 @@ class PdfExtractor(BaseExtractor):
text_page = page.get_textpage()
content = text_page.get_text_range()
text_page.close()
image_content = self._extract_images(page)
if image_content:
content += "\n" + image_content
page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata)
finally:
pdf_reader.close()
def _extract_images(self, page) -> str:
"""
Extract images from a PDF page, save them to storage and database,
and return markdown image links.
Args:
page: pypdfium2 page object.
Returns:
Markdown string containing links to the extracted images.
"""
image_content = []
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:
try:
# Extract image bytes
img_byte_arr = io.BytesIO()
# Extract DCTDecode (JPEG) and JPXDecode (JPEG 2000) images directly
# Fallback to png for other formats
obj.extract(img_byte_arr, fb_format="png")
img_bytes = img_byte_arr.getvalue()
if not img_bytes:
continue
header = img_bytes[: self.MAX_MAGIC_LEN]
image_ext = None
mime_type = None
for magic, ext, mime in self.IMAGE_FORMATS:
if header.startswith(magic):
image_ext = ext
mime_type = mime
break
if not image_ext or not mime_type:
continue
file_uuid = str(uuid.uuid4())
file_key = "image_files/" + self._tenant_id + "/" + file_uuid + "." + image_ext
storage.save(file_key, img_bytes)
# save file to db
upload_file = UploadFile(
tenant_id=self._tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=file_key,
name=file_key,
size=len(img_bytes),
extension=image_ext,
mime_type=mime_type,
created_by=self._user_id,
created_by_role=CreatorUserRole.ACCOUNT,
created_at=naive_utc_now(),
used=True,
used_by=self._user_id,
used_at=naive_utc_now(),
)
upload_files.append(upload_file)
image_content.append(f"![image]({base_url}/files/{upload_file.id}/file-preview)")
except Exception as e:
logger.warning("Failed to extract image from PDF: %s", e)
continue
except Exception as e:
logger.warning("Failed to get objects from PDF page: %s", e)
if upload_files:
db.session.add_all(upload_files)
db.session.commit()
return "\n".join(image_content)
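The removed extractor identifies image formats by sniffing leading magic bytes. A minimal sketch of that detection step on its own, with only two of the formats from the `IMAGE_FORMATS` table above:

```python
IMAGE_FORMATS = [
    (b"\xff\xd8\xff", "jpg", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "png", "image/png"),
]
MAX_MAGIC_LEN = max(len(magic) for magic, _, _ in IMAGE_FORMATS)


def sniff_image(data: bytes) -> tuple[str, str] | None:
    """Return (extension, mime_type) when the header matches a known format."""
    header = data[:MAX_MAGIC_LEN]
    for magic, ext, mime in IMAGE_FORMATS:
        if header.startswith(magic):
            return ext, mime
    return None


print(sniff_image(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16))  # ('png', 'image/png')
print(sniff_image(b"not an image"))                      # None
```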

View File

@ -515,11 +515,7 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
dataset_count = len(available_datasets)
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@ -538,9 +534,6 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@ -564,26 +557,12 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
# Poll threads with short timeout to detect errors quickly (fail-fast)
while any(t.is_alive() for t in all_threads):
for thread in all_threads:
thread.join(timeout=0.1)
if thread_exceptions:
cancel_event.set()
break
if thread_exceptions:
break
if thread_exceptions:
raise thread_exceptions[0]
for thread in all_threads:
thread.join()
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)
if all_documents:
@ -1425,57 +1404,42 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
dataset_count: int,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
try:
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
# Check for cancellation signal
if cancel_event and cancel_event.is_set():
break
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
# Poll threads with short timeout to respond quickly to cancellation
while any(t.is_alive() for t in threads):
for thread in threads:
thread.join(timeout=0.1)
if cancel_event and cancel_event.is_set():
break
if cancel_event and cancel_event.is_set():
break
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
@ -1506,8 +1470,3 @@ class DatasetRetrieval:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
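For reference, the fail-fast pattern removed in this file boils down to: workers check a shared cancellation event, push exceptions into a shared list, and the coordinator polls with a short join timeout. A generic standalone sketch (not the retrieval code itself):

```python
import threading
import time

cancel_event = threading.Event()
exceptions: list[Exception] = []


def worker(n: int) -> None:
    try:
        for i in range(20):
            if cancel_event.is_set():
                return  # another worker failed; stop early
            if n == 1 and i == 3:
                raise ValueError("retrieval failed")  # simulated failure
            time.sleep(0.05)
    except Exception as e:
        cancel_event.set()
        exceptions.append(e)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()

# Poll with a short join timeout so a failure is noticed quickly.
while any(t.is_alive() for t in threads):
    for t in threads:
        t.join(timeout=0.1)
    if exceptions:
        break

print(f"failed fast: {exceptions[0] if exceptions else None}")
```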

View File

@ -378,7 +378,7 @@ class ApiBasedToolSchemaParser:
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], ApiProviderSchemaType]:
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle

View File

@ -4,7 +4,6 @@ import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Preserves markdown links like [text](url) at the start.
Args:
text (str): The input text to process.
@ -12,11 +11,6 @@ def remove_leading_symbols(text: str) -> str:
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Check if text starts with a markdown link - preserve it
markdown_link_pattern = r"^\[([^\]]+)\]\((https?://[^)]+)\)"
if re.match(markdown_link_pattern, text):
return text
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'

View File

@ -54,6 +54,7 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("app not found")
user = session.get(Account, db_provider.user_id) if db_provider.user_id else None
controller = WorkflowToolProviderController(
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
@ -66,7 +67,7 @@ class WorkflowToolProviderController(ToolProviderController):
credentials_schema=[],
plugin_id=None,
),
provider_id=db_provider.id,
provider_id="",
)
controller.tools = [

View File

@ -60,7 +60,6 @@ class SkipPropagator:
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node

View File

@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import Any, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@ -12,7 +13,6 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@ -20,41 +20,9 @@ from .exc import (
OutputValidationError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -67,16 +35,11 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
return "1"
@ -97,8 +60,7 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
_ = self._select_code_provider(code_language)
result = self._code_executor.execute_workflow_code_template(
result = CodeExecutor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@ -113,12 +75,6 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@ -129,10 +85,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if len(value) > self._limits.max_string_length:
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {self._limits.max_string_length} characters"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
)
return value.replace("\x00", "")
@ -153,20 +109,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if value > self._limits.max_number or value < self._limits.min_number:
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > self._limits.max_precision:
if precision > dify_config.CODE_MAX_PRECISION:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {self._limits.max_precision} digits."
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
)
return value
@ -181,8 +137,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > self._limits.max_depth:
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@ -316,10 +272,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > self._limits.max_number_array_length:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {self._limits.max_number_array_length} elements."
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
for i, inner_value in enumerate(value):
@ -349,10 +305,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > self._limits.max_string_array_length:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {self._limits.max_string_array_length} elements."
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
)
transformed_result[output_name] = [
@ -370,10 +326,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > self._limits.max_object_array_length:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {self._limits.max_object_array_length} elements."
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
)
for i, value in enumerate(result[output_name]):

View File

@ -1,13 +0,0 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class CodeNodeLimits:
max_string_length: int
max_number: int | float
min_number: int | float
max_precision: int
max_depth: int
max_number_array_length: int
max_string_array_length: int
max_object_array_length: int

View File

@ -1,16 +1,10 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@ -33,27 +27,9 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@ -96,17 +72,6 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
return node_class(
id=node_id,
config=node_config,

View File

@ -12,8 +12,9 @@ from dify_app import DifyApp
def _get_celery_ssl_options() -> dict[str, Any] | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Use REDIS_USE_SSL for consistency with the main Redis client
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.BROKER_USE_SSL:
if not dify_config.REDIS_USE_SSL:
return None
# Check if Celery is actually using Redis
@ -46,11 +47,7 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context
with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}

View File

@ -1,19 +1,18 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@ -26,53 +25,27 @@ def init_app(app: DifyApp):
)
)
# Console handler
# Always add StreamHandler to log to console
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
handler.addFilter(RequestIdFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter
return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@ -84,51 +57,34 @@ def _apply_timezone(handlers: list[logging.Handler]):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in handlers:
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter.converter = time_converter # type: ignore[attr-defined]
handler.formatter.converter = time_converter
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
def get_request_id() -> str:
"""Get request ID for current request context.
Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id
return _get_request_id()
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
return True
class RequestIdFormatter(logging.Formatter):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@ -137,7 +93,6 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)
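A minimal stdlib-only sketch of the filter-plus-format pattern this extension relies on (ContextFilter is an assumed stand-in for TraceContextFilter/IdentityContextFilter; the format string references the same req_id/trace_id attributes set above):
import logging

class ContextFilter(logging.Filter):
    # Ensures the attributes referenced in the format string always exist,
    # even for records emitted outside a request context.
    def filter(self, record: logging.LogRecord) -> bool:
        record.req_id = getattr(record, "req_id", "")
        record.trace_id = getattr(record, "trace_id", "")
        return True

handler = logging.StreamHandler()
handler.addFilter(ContextFilter())
handler.setFormatter(logging.Formatter(
    "%(asctime)s [%(levelname)s] req_id=%(req_id)s trace_id=%(trace_id)s %(message)s"
))
logging.basicConfig(level=logging.INFO, handlers=[handler], force=True)
logging.getLogger(__name__).info("request handled")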

View File

@ -19,43 +19,26 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
Unlike creating a new span, this records exceptions on the existing span
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
if not record.exc_info:
return
from opentelemetry.trace import get_current_span
span = get_current_span()
if not span or not span.is_recording():
return
# Record exception on the current span instead of creating a new one
span.set_status(StatusCode.ERROR, record.getMessage())
# Add log context as span events/attributes
span.add_event(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
) as span:
span.set_status(StatusCode.ERROR)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
span.set_attribute("exception.message", str(record.exc_info[1]))
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:
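The body of instrument_exception_logging is not shown in this hunk; an assumed wiring sketch (every name other than ExceptionLoggingHandler is hypothetical) would attach the handler to the root logger so logger.exception() calls land on the current span:
import logging

def _attach_exception_handler() -> None:
    # Hypothetical wiring: mirror ERROR-level records carrying exc_info onto the active span.
    logging.getLogger().addHandler(ExceptionLoggingHandler(level=logging.ERROR))

_attach_exception_handler()
try:
    raise ValueError("boom")
except ValueError:
    # With a span recording, the exception is attached to it instead of starting a new span.
    logging.getLogger(__name__).exception("request failed")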

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@ -12,7 +12,7 @@ annotation_fields = {
}
def build_annotation_model(api_or_ns: Namespace):
def build_annotation_model(api_or_ns: Api | Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)

View File

@ -1,338 +1,236 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.file import File
JSONValue: TypeAlias = Any
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class MessageFile(ResponseModel):
id: str
filename: str
type: str
url: str | None = None
mime_type: str | None = None
size: int | None = None
transfer_method: str
belongs_to: str | None = None
upload_file_id: str | None = None
@field_validator("transfer_method", mode="before")
@classmethod
def _normalize_transfer_method(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
class SimpleConversation(ResponseModel):
id: str
name: str
inputs: dict[str, JSONValue]
status: str
introduction: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class ConversationInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SimpleConversation]
class ConversationDelete(ResponseModel):
result: str
class ResultResponse(ResponseModel):
result: str
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class Feedback(ResponseModel):
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account: SimpleAccount | None = None
class Annotation(ResponseModel):
id: str
question: str | None = None
content: str
account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AnnotationHitHistory(ResponseModel):
annotation_id: str
annotation_create_account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AgentThought(ResponseModel):
id: str
chain_id: str | None = None
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
message_id: str
position: int
thought: str | None = None
tool: str | None = None
tool_labels: JSONValue
tool_input: str | None = None
created_at: int | None = None
observation: str | None = None
files: list[str]
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
@model_validator(mode="after")
def _fallback_chain_id(self):
if self.chain_id is None and self.message_chain_id:
self.chain_id = self.message_chain_id
return self
class MessageDetail(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue
message_tokens: int
answer: str
answer_tokens: int
provider_response_latency: float
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback]
workflow_run_id: str | None = None
annotation: Annotation | None = None
annotation_hit_history: AnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
metadata: JSONValue
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class FeedbackStat(ResponseModel):
like: int
dislike: int
class StatusCount(ResponseModel):
success: int
failed: int
partial_success: int
class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = None
model: JSONValue | None = None
user_input_form: JSONValue | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = None
class SimpleModelConfig(ResponseModel):
model: JSONValue | None = None
pre_prompt: str | None = None
class SimpleMessageDetail(ResponseModel):
inputs: dict[str, JSONValue]
query: str
message: str
answer: str
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
class Conversation(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotation: Annotation | None = None
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
message: SimpleMessageDetail | None = None
class ConversationPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[Conversation]
class ConversationMessageDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message: MessageDetail | None = None
class ConversationWithSummary(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
name: str
summary: str
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
status_count: StatusCount | None = None
class ConversationWithSummaryPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationWithSummary]
class ConversationDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
introduction: str | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValue) -> JSONValue:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value
def message_text(value: JSONValue) -> str:
if isinstance(value, list) and value:
first = value[0]
if isinstance(first, dict):
text = first.get("text")
if isinstance(text, str):
return text
return ""
def extract_model_config(value: object | None) -> dict[str, JSONValue]:
if value is None:
return {}
if isinstance(value, dict):
return value
if hasattr(value, "to_dict"):
return value.to_dict()
return {}
from flask_restx import Api, Namespace, fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
from .raws import FilesContainedField
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
feedback_fields = {
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_fields, allow_null=True),
}
annotation_fields = {
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
annotation_hit_history_fields = {
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
message_file_fields = {
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
}
def build_message_file_model(api_or_ns: Api | Namespace):
"""Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields)
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
message_detail_fields = {
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_fields)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
}
simple_model_config_fields = {
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
}
simple_message_detail_fields = {
"inputs": FilesContainedField,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
}
conversation_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"from_account_name": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"model_config": fields.Nested(simple_model_config_fields),
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
}
conversation_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
}
conversation_message_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_fields),
"message": fields.Nested(message_detail_fields, attribute="first_message"),
}
conversation_with_summary_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"from_account_name": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"status_count": fields.Nested(status_count_fields),
}
conversation_with_summary_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
}
conversation_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
}
simple_conversation_fields = {
"id": fields.String,
"name": fields.String,
"inputs": FilesContainedField,
"status": fields.String,
"introduction": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
conversation_delete_fields = {
"result": fields.String,
}
conversation_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(simple_conversation_fields)),
}
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns)
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
def build_conversation_delete_model(api_or_ns: Api | Namespace):
"""Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
def build_simple_conversation_model(api_or_ns: Api | Namespace):
"""Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@ -29,12 +29,12 @@ conversation_variable_infinite_scroll_pagination_fields = {
}
def build_conversation_variable_model(api_or_ns: Namespace):
def build_conversation_variable_model(api_or_ns: Api | Namespace):
"""Build the conversation variable model for the API or Namespace."""
return api_or_ns.model("ConversationVariable", conversation_variable_fields)
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Namespace):
def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
"""Build the conversation variable infinite scroll pagination model for the API or Namespace."""
# Build the nested variable model first
conversation_variable_model = build_conversation_variable_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ simple_end_user_fields = {
}
def build_simple_end_user_model(api_or_ns: Namespace):
def build_simple_end_user_model(api_or_ns: Api | Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from libs.helper import TimestampField
@ -14,7 +14,7 @@ upload_config_fields = {
}
def build_upload_config_model(api_or_ns: Namespace):
def build_upload_config_model(api_or_ns: Api | Namespace):
"""Build the upload config model for the API or Namespace.
Args:
@ -39,7 +39,7 @@ file_fields = {
}
def build_file_model(api_or_ns: Namespace):
def build_file_model(api_or_ns: Api | Namespace):
"""Build the file model for the API or Namespace.
Args:
@ -57,7 +57,7 @@ remote_file_info_fields = {
}
def build_remote_file_info_model(api_or_ns: Namespace):
def build_remote_file_info_model(api_or_ns: Api | Namespace):
"""Build the remote file info model for the API or Namespace.
Args:
@ -81,7 +81,7 @@ file_fields_with_signed_url = {
}
def build_file_with_signed_url_model(api_or_ns: Namespace):
def build_file_with_signed_url_model(api_or_ns: Api | Namespace):
"""Build the file with signed URL model for the API or Namespace.
Args:

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from libs.helper import AvatarUrlField, TimestampField
@ -9,7 +9,7 @@ simple_account_fields = {
}
def build_simple_account_model(api_or_ns: Namespace):
def build_simple_account_model(api_or_ns: Api | Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)

View File

@ -1,137 +1,77 @@
from __future__ import annotations
from flask_restx import Api, Namespace, fields
from datetime import datetime
from typing import TypeAlias
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from pydantic import BaseModel, ConfigDict, Field, field_validator
from .raws import FilesContainedField
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue
feedback_fields = {
"rating": fields.String,
}
class ResponseModel(BaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
def build_feedback_model(api_or_ns: Api | Namespace):
"""Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields)
class SimpleFeedback(ResponseModel):
rating: str | None = None
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
class RetrieverResource(ResponseModel):
id: str
message_id: str
position: int
dataset_id: str | None = None
dataset_name: str | None = None
document_id: str | None = None
document_name: str | None = None
data_source_type: str | None = None
segment_id: str | None = None
score: float | None = None
hit_count: int | None = None
word_count: int | None = None
segment_position: int | None = None
index_node_hash: str | None = None
content: str | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
def build_agent_thought_model(api_or_ns: Api | Namespace):
"""Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields)
class MessageListItem(ResponseModel):
id: str
conversation_id: str
parent_message_id: str | None = None
inputs: dict[str, JSONValueType]
query: str
answer: str = Field(validation_alias="re_sign_file_url_answer")
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
retriever_resources: list[RetrieverResource]
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
status: str
error: str | None = None
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"status": fields.String,
"error": fields.String,
}
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class WebMessageListItem(MessageListItem):
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
class MessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[MessageListItem]
class WebMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[WebMessageListItem]
class SavedMessageItem(ResponseModel):
id: str
inputs: dict[str, JSONValueType]
query: str
answer: str
message_files: list[MessageFile]
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
created_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class SavedMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SavedMessageItem]
class SuggestedQuestionsResponse(ResponseModel):
data: list[str]
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValueType) -> JSONValueType:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
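A small check of the mode="before" validator pattern used by the models above: a datetime supplied for created_at is normalized to epoch seconds during validation (the sample values are made up):
from datetime import datetime, timezone

item = SavedMessageItem.model_validate({
    "id": "m-1",
    "inputs": {},
    "query": "hi",
    "answer": "hello",
    "message_files": [],
    "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),
})
assert item.created_at == 1704067200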

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
dataset_tag_fields = {
"id": fields.String,
@ -8,5 +8,5 @@ dataset_tag_fields = {
}
def build_dataset_tag_fields(api_or_ns: Namespace):
def build_dataset_tag_fields(api_or_ns: Api | Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields
from fields.member_fields import build_simple_account_model, simple_account_fields
@ -17,7 +17,7 @@ workflow_app_log_partial_fields = {
}
def build_workflow_app_log_partial_model(api_or_ns: Namespace):
def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace):
"""Build the workflow app log partial model for the API or Namespace."""
workflow_run_model = build_workflow_run_for_log_model(api_or_ns)
simple_account_model = build_simple_account_model(api_or_ns)
@ -43,7 +43,7 @@ workflow_app_log_pagination_fields = {
}
def build_workflow_app_log_pagination_model(api_or_ns: Namespace):
def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace):
"""Build the workflow app log pagination model for the API or Namespace."""
# Build the nested partial model first
workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns)

View File

@ -1,4 +1,4 @@
from flask_restx import Namespace, fields
from flask_restx import Api, Namespace, fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@ -19,7 +19,7 @@ workflow_run_for_log_fields = {
}
def build_workflow_run_for_log_model(api_or_ns: Namespace):
def build_workflow_run_for_log_model(api_or_ns: Api | Namespace):
return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields)

View File

@ -1,347 +0,0 @@
"""
Archive Storage Client for S3-compatible storage.
This module provides a dedicated storage client for archiving or exporting logs
to S3-compatible object storage.
"""
import base64
import datetime
import gzip
import hashlib
import logging
from collections.abc import Generator
from typing import Any, cast
import boto3
import orjson
from botocore.client import Config
from botocore.exceptions import ClientError
from configs import dify_config
logger = logging.getLogger(__name__)
class ArchiveStorageError(Exception):
"""Base exception for archive storage operations."""
pass
class ArchiveStorageNotConfiguredError(ArchiveStorageError):
"""Raised when archive storage is not properly configured."""
pass
class ArchiveStorage:
"""
S3-compatible storage client for archiving or exporting.
This client provides methods for storing and retrieving archived data in JSONL+gzip format.
"""
def __init__(self, bucket: str):
if not dify_config.ARCHIVE_STORAGE_ENABLED:
raise ArchiveStorageNotConfiguredError("Archive storage is not enabled")
if not bucket:
raise ArchiveStorageNotConfiguredError("Archive storage bucket is not configured")
if not all(
[
dify_config.ARCHIVE_STORAGE_ENDPOINT,
bucket,
dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
dify_config.ARCHIVE_STORAGE_SECRET_KEY,
]
):
raise ArchiveStorageNotConfiguredError(
"Archive storage configuration is incomplete. "
"Required: ARCHIVE_STORAGE_ENDPOINT, ARCHIVE_STORAGE_ACCESS_KEY, "
"ARCHIVE_STORAGE_SECRET_KEY, and a bucket name"
)
self.bucket = bucket
self.client = boto3.client(
"s3",
endpoint_url=dify_config.ARCHIVE_STORAGE_ENDPOINT,
aws_access_key_id=dify_config.ARCHIVE_STORAGE_ACCESS_KEY,
aws_secret_access_key=dify_config.ARCHIVE_STORAGE_SECRET_KEY,
region_name=dify_config.ARCHIVE_STORAGE_REGION,
config=Config(s3={"addressing_style": "path"}),
)
# Verify bucket accessibility
try:
self.client.head_bucket(Bucket=self.bucket)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "404":
raise ArchiveStorageNotConfiguredError(f"Archive bucket '{self.bucket}' does not exist")
elif error_code == "403":
raise ArchiveStorageNotConfiguredError(f"Access denied to archive bucket '{self.bucket}'")
else:
raise ArchiveStorageError(f"Failed to access archive bucket: {e}")
def put_object(self, key: str, data: bytes) -> str:
"""
Upload an object to the archive storage.
Args:
key: Object key (path) within the bucket
data: Binary data to upload
Returns:
MD5 checksum of the uploaded data
Raises:
ArchiveStorageError: If upload fails
"""
checksum = hashlib.md5(data).hexdigest()
try:
self.client.put_object(
Bucket=self.bucket,
Key=key,
Body=data,
ContentMD5=self._content_md5(data),
)
logger.debug("Uploaded object: %s (size=%d, checksum=%s)", key, len(data), checksum)
return checksum
except ClientError as e:
raise ArchiveStorageError(f"Failed to upload object '{key}': {e}")
def get_object(self, key: str) -> bytes:
"""
Download an object from the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
Binary data of the object
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
return response["Body"].read()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to download object '{key}': {e}")
def get_object_stream(self, key: str) -> Generator[bytes, None, None]:
"""
Stream an object from the archive storage.
Args:
key: Object key (path) within the bucket
Yields:
Chunks of binary data
Raises:
ArchiveStorageError: If download fails
FileNotFoundError: If object does not exist
"""
try:
response = self.client.get_object(Bucket=self.bucket, Key=key)
yield from response["Body"].iter_chunks()
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == "NoSuchKey":
raise FileNotFoundError(f"Archive object not found: {key}")
raise ArchiveStorageError(f"Failed to stream object '{key}': {e}")
def object_exists(self, key: str) -> bool:
"""
Check if an object exists in the archive storage.
Args:
key: Object key (path) within the bucket
Returns:
True if object exists, False otherwise
"""
try:
self.client.head_object(Bucket=self.bucket, Key=key)
return True
except ClientError:
return False
def delete_object(self, key: str) -> None:
"""
Delete an object from the archive storage.
Args:
key: Object key (path) within the bucket
Raises:
ArchiveStorageError: If deletion fails
"""
try:
self.client.delete_object(Bucket=self.bucket, Key=key)
logger.debug("Deleted object: %s", key)
except ClientError as e:
raise ArchiveStorageError(f"Failed to delete object '{key}': {e}")
def generate_presigned_url(self, key: str, expires_in: int = 3600) -> str:
"""
Generate a pre-signed URL for downloading an object.
Args:
key: Object key (path) within the bucket
expires_in: URL validity duration in seconds (default: 1 hour)
Returns:
Pre-signed URL string.
Raises:
ArchiveStorageError: If generation fails
"""
try:
return self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket, "Key": key},
ExpiresIn=expires_in,
)
except ClientError as e:
raise ArchiveStorageError(f"Failed to generate pre-signed URL for '{key}': {e}")
def list_objects(self, prefix: str) -> list[str]:
"""
List objects under a given prefix.
Args:
prefix: Object key prefix to filter by
Returns:
List of object keys matching the prefix
"""
keys = []
paginator = self.client.get_paginator("list_objects_v2")
try:
for page in paginator.paginate(Bucket=self.bucket, Prefix=prefix):
for obj in page.get("Contents", []):
keys.append(obj["Key"])
except ClientError as e:
raise ArchiveStorageError(f"Failed to list objects with prefix '{prefix}': {e}")
return keys
@staticmethod
def _content_md5(data: bytes) -> str:
"""Calculate base64-encoded MD5 for Content-MD5 header."""
return base64.b64encode(hashlib.md5(data).digest()).decode()
@staticmethod
def serialize_to_jsonl_gz(records: list[dict[str, Any]]) -> bytes:
"""
Serialize records to gzipped JSONL format.
Args:
records: List of dictionaries to serialize
Returns:
Gzipped JSONL bytes
"""
lines = []
for record in records:
# Convert datetime objects to ISO format strings
serialized = ArchiveStorage._serialize_record(record)
lines.append(orjson.dumps(serialized))
jsonl_content = b"\n".join(lines)
if jsonl_content:
jsonl_content += b"\n"
return gzip.compress(jsonl_content)
@staticmethod
def deserialize_from_jsonl_gz(data: bytes) -> list[dict[str, Any]]:
"""
Deserialize gzipped JSONL data to records.
Args:
data: Gzipped JSONL bytes
Returns:
List of dictionaries
"""
jsonl_content = gzip.decompress(data)
records = []
for line in jsonl_content.splitlines():
if line:
records.append(orjson.loads(line))
return records
@staticmethod
def _serialize_record(record: dict[str, Any]) -> dict[str, Any]:
"""Serialize a single record, converting special types."""
def _serialize(item: Any) -> Any:
if isinstance(item, datetime.datetime):
return item.isoformat()
if isinstance(item, dict):
return {key: _serialize(value) for key, value in item.items()}
if isinstance(item, list):
return [_serialize(value) for value in item]
return item
return cast(dict[str, Any], _serialize(record))
@staticmethod
def compute_checksum(data: bytes) -> str:
"""Compute MD5 checksum of data."""
return hashlib.md5(data).hexdigest()
# Singleton instance (lazy initialization)
_archive_storage: ArchiveStorage | None = None
_export_storage: ArchiveStorage | None = None
def get_archive_storage() -> ArchiveStorage:
"""
Get the archive storage singleton instance.
Returns:
ArchiveStorage instance
Raises:
ArchiveStorageNotConfiguredError: If archive storage is not configured
"""
global _archive_storage
if _archive_storage is None:
archive_bucket = dify_config.ARCHIVE_STORAGE_ARCHIVE_BUCKET
if not archive_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive storage bucket is not configured. Required: ARCHIVE_STORAGE_ARCHIVE_BUCKET"
)
_archive_storage = ArchiveStorage(bucket=archive_bucket)
return _archive_storage
def get_export_storage() -> ArchiveStorage:
"""
Get the export storage singleton instance.
Returns:
ArchiveStorage instance
"""
global _export_storage
if _export_storage is None:
export_bucket = dify_config.ARCHIVE_STORAGE_EXPORT_BUCKET
if not export_bucket:
raise ArchiveStorageNotConfiguredError(
"Archive export bucket is not configured. Required: ARCHIVE_STORAGE_EXPORT_BUCKET"
)
_export_storage = ArchiveStorage(bucket=export_bucket)
return _export_storage
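Assumed round trip with the client above (bucket configuration comes from the environment; the object key and record contents are made up):
storage = get_archive_storage()
records = [{"id": "run-1", "archived_at": datetime.datetime.now(datetime.timezone.utc)}]
blob = ArchiveStorage.serialize_to_jsonl_gz(records)
checksum = storage.put_object("workflow_runs/2025/01/part-0001.jsonl.gz", blob)
restored = ArchiveStorage.deserialize_from_jsonl_gz(
    storage.get_object("workflow_runs/2025/01/part-0001.jsonl.gz")
)
assert ArchiveStorage.compute_checksum(blob) == checksum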

View File

@ -1,4 +1,5 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
@ -108,8 +109,11 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
# Explicit log_exception call removed to avoid duplicate log entries
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
return data, status_code

View File

@ -11,6 +11,9 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@ -20,17 +23,31 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
batch_op.drop_column('description_str')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
# ### end Alembic commands ###
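The _is_pg(conn) branch above repeats across the migrations in this changeset; a distilled sketch of the pattern (the helper function name is illustrative):
import sqlalchemy as sa
from alembic import op

import models.types

def _is_pg(conn) -> bool:
    return conn.dialect.name == "postgresql"

def _add_description_column() -> None:
    # Pick the PostgreSQL-native type on Postgres, otherwise the portable wrapper.
    conn = op.get_bind()
    description_type = sa.Text() if _is_pg(conn) else models.types.LongText()
    with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('description', description_type, nullable=False))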

View File

@ -7,10 +7,14 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@ -28,7 +32,13 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
else:
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
# ### end Alembic commands ###

View File

@ -11,6 +11,9 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@ -20,9 +23,16 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
else:
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
# ### end Alembic commands ###

View File

@ -9,6 +9,11 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@ -18,30 +23,58 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=True)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=False)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -8,6 +8,7 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'

View File

@ -28,45 +28,85 @@ def upgrade():
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
if _is_pg(conn):
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
else:
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###

View File

@ -49,33 +49,57 @@ def upgrade():
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=False)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=True)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
if _is_pg(conn):
with op.batch_alter_table('messages', schema=None) as batch_op:

View File

@ -86,30 +86,57 @@ def upgrade():
def migrate_existing_provider_models_data():
"""migrate provider_models table data to provider_model_credentials"""
# Define table structure for data manipulatio
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
conn = op.get_bind()
# Define table structure for data manipulation
if _is_pg(conn):
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', sa.Text()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
else:
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
if _is_pg(conn):
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', sa.Text()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
else:
provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
column('provider_name', sa.String()),
column('model_name', sa.String()),
column('model_type', sa.String()),
column('credential_name', sa.String()),
column('encrypted_config', models.types.LongText()),
column('created_at', sa.DateTime()),
column('updated_at', sa.DateTime())
)
# Get database connection
@ -156,8 +183,14 @@ def migrate_existing_provider_models_data():
def downgrade():
# Re-add encrypted_config column to provider_models table
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', sa.Text(), nullable=True))
else:
with op.batch_alter_table('provider_models', schema=None) as batch_op:
batch_op.add_column(sa.Column('encrypted_config', models.types.LongText(), nullable=True))
if not context.is_offline_mode():
# Migrate data back from provider_model_credentials to provider_models
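Note: the row-copy body of migrate_existing_provider_models_data is truncated in this view. A minimal sketch of how such a copy is typically written with the lightweight table()/column() constructs defined above — the helper name, the SELECT/INSERT details, the generated id, and the credential name are illustrative, not the exact code of this revision:

import uuid
from alembic import context
from sqlalchemy import select

def _copy_provider_model_credentials(conn, source, target):
    # Offline (--sql) runs cannot move data, so skip the copy there,
    # mirroring the is_offline_mode() guard used elsewhere in this file.
    if context.is_offline_mode():
        return
    rows = conn.execute(
        select(source).where(source.c.encrypted_config.isnot(None))
    ).fetchall()
    for row in rows:
        conn.execute(
            target.insert().values(
                id=str(uuid.uuid4()),            # illustrative id generation
                tenant_id=row.tenant_id,
                provider_name=row.provider_name,
                model_name=row.model_name,
                model_type=row.model_type,
                credential_name='API KEY 1',     # illustrative default name
                encrypted_config=row.encrypted_config,
                created_at=row.created_at,
                updated_at=row.updated_at,
            )
        )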

View File

@ -8,6 +8,7 @@ Create Date: 2025-08-20 17:47:17.015695
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7
def _is_pg(conn):

View File

@ -9,6 +9,8 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@ -21,7 +23,12 @@ depends_on = None
def upgrade():
# Add encrypted_headers column to tool_mcp_providers table
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
conn = op.get_bind()
if _is_pg(conn):
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', sa.Text(), nullable=True))
else:
op.add_column('tool_mcp_providers', sa.Column('encrypted_headers', models.types.LongText(), nullable=True))
def downgrade():

View File

@ -44,7 +44,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
)
if _is_pg(conn):
op.create_table('datasource_oauth_tenant_params',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -71,7 +70,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_oauth_tenant_config_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='datasource_oauth_tenant_config_unique')
)
if _is_pg(conn):
op.create_table('datasource_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -106,7 +104,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', 'name', name='datasource_provider_unique_name')
)
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.create_index('datasource_provider_auth_type_provider_idx', ['tenant_id', 'plugin_id', 'provider'], unique=False)
@ -136,7 +133,6 @@ def upgrade():
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_pipeline_execution_log_pkey')
)
with op.batch_alter_table('document_pipeline_execution_logs', schema=None) as batch_op:
batch_op.create_index('document_pipeline_execution_logs_document_id_idx', ['document_id'], unique=False)
@ -178,7 +174,6 @@ def upgrade():
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
)
if _is_pg(conn):
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -198,6 +193,7 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
else:
# MySQL: Use compatible syntax
op.create_table('pipeline_customized_templates',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
@ -215,7 +211,6 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
)
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
@ -241,7 +236,6 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_recommended_plugin_pkey')
)
if _is_pg(conn):
op.create_table('pipelines',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -272,7 +266,6 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
)
if _is_pg(conn):
op.create_table('workflow_draft_variable_files',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -299,7 +292,6 @@ def upgrade():
sa.Column('value_type', sa.String(20), nullable=False),
sa.PrimaryKeyConstraint('id', name=op.f('workflow_draft_variable_files_pkey'))
)
if _is_pg(conn):
op.create_table('workflow_node_execution_offload',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
@ -324,7 +316,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name=op.f('workflow_node_execution_offload_pkey')),
sa.UniqueConstraint('node_execution_id', 'type', name=op.f('workflow_node_execution_offload_node_execution_id_key'))
)
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
@ -351,7 +342,6 @@ def upgrade():
comment='Indicates whether the current value is the default for a conversation variable. Always `FALSE` for other types of variables.',)
)
batch_op.create_index('workflow_draft_variable_file_id_idx', ['file_id'], unique=False)
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
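Note: a pattern that recurs throughout this file is that the PostgreSQL branches keep the database-side server_default=sa.text('uuidv7()') on primary keys, while the MySQL branches drop the server default and leave id generation to the application. A minimal sketch of that split, reusing the _is_pg helper from this revision — the _id_column helper name is hypothetical, and it assumes the uuidv7() SQL function is already available on PostgreSQL:

import sqlalchemy as sa
import models as models

def _id_column(conn):
    # PostgreSQL: let the database fill the primary key via uuidv7().
    # MySQL: no server default; the application supplies the UUID on insert.
    if _is_pg(conn):
        return sa.Column('id', models.types.StringUUID(),
                         server_default=sa.text('uuidv7()'), nullable=False)
    return sa.Column('id', models.types.StringUUID(), nullable=False)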

View File

@ -9,6 +9,8 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
@ -31,9 +33,15 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', sa.UUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', sa.UUID(), autoincrement=False, nullable=True))
else:
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), autoincrement=False, nullable=False))
batch_op.add_column(sa.Column('updated_by', models.types.StringUUID(), autoincrement=False, nullable=True))
# ### end Alembic commands ###

View File

@ -9,6 +9,7 @@ Create Date: 2025-10-22 16:11:31.805407
from alembic import op
import models as models
import sqlalchemy as sa
from libs.uuid_utils import uuidv7
def _is_pg(conn):
return conn.dialect.name == "postgresql"

View File

@ -105,7 +105,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_trigger_oauth_tenant_client')
)
if _is_pg(conn):
op.create_table('trigger_subscriptions',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
@ -144,7 +143,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='trigger_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'provider_id', 'name', name='unique_trigger_provider')
)
with op.batch_alter_table('trigger_subscriptions', schema=None) as batch_op:
batch_op.create_index('idx_trigger_providers_endpoint', ['endpoint_id'], unique=True)
batch_op.create_index('idx_trigger_providers_tenant_endpoint', ['tenant_id', 'endpoint_id'], unique=False)
@ -178,7 +176,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_plugin_trigger_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node_subscription')
)
with op.batch_alter_table('workflow_plugin_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_plugin_trigger_tenant_subscription_idx', ['tenant_id', 'subscription_id', 'event_name'], unique=False)
@ -210,7 +207,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id', name='workflow_schedule_plan_pkey'),
sa.UniqueConstraint('app_id', 'node_id', name='uniq_app_node')
)
with op.batch_alter_table('workflow_schedule_plans', schema=None) as batch_op:
batch_op.create_index('workflow_schedule_plan_next_idx', ['next_run_at'], unique=False)
@ -268,7 +264,6 @@ def upgrade():
sa.Column('finished_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id', name='workflow_trigger_log_pkey')
)
with op.batch_alter_table('workflow_trigger_logs', schema=None) as batch_op:
batch_op.create_index('workflow_trigger_log_created_at_idx', ['created_at'], unique=False)
batch_op.create_index('workflow_trigger_log_status_idx', ['status'], unique=False)
@ -304,7 +299,6 @@ def upgrade():
sa.UniqueConstraint('app_id', 'node_id', name='uniq_node'),
sa.UniqueConstraint('webhook_id', name='uniq_webhook_id')
)
with op.batch_alter_table('workflow_webhook_triggers', schema=None) as batch_op:
batch_op.create_index('workflow_webhook_trigger_tenant_idx', ['tenant_id'], unique=False)
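Note: the server_default=sa.text('uuid_generate_v4()') used in the PostgreSQL branch of this file relies on the uuid-ossp extension (which provides uuid_generate_v4()) already being installed. A migration that cannot assume this would typically guard it; a sketch with a hypothetical helper name:

from alembic import op

def _ensure_uuid_ossp(conn):
    # PostgreSQL only; the MySQL branches avoid server-side UUID defaults entirely.
    if _is_pg(conn):
        op.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')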

View File

@ -11,6 +11,9 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '23db93619b9d'
down_revision = '8ae9bc661daa'
@ -20,8 +23,14 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', sa.Text(), nullable=True))
else:
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('message_files', models.types.LongText(), nullable=True))
# ### end Alembic commands ###
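Note: across these hunks the PostgreSQL branch uses plain sa.Text() while the fallback branch keeps models.types.LongText(). The definition of LongText lives in the repository's models/types.py and is not part of this diff; assuming it follows the common with_variant pattern, it would look roughly like the sketch below, which is why plain TEXT is sufficient on PostgreSQL:

import sqlalchemy as sa
from sqlalchemy.dialects.mysql import LONGTEXT

def LongText():
    # Illustrative sketch only: TEXT everywhere, rendered as LONGTEXT on MySQL.
    return sa.Text().with_variant(LONGTEXT(), "mysql")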

View File

@ -62,8 +62,14 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True))
else:
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('annotation_reply', models.types.LongText(), autoincrement=False, nullable=True))
with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op:
batch_op.drop_index('app_annotation_settings_app_idx')

View File

@ -11,6 +11,9 @@ from alembic import op
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '2a3aebbbf4bb'
down_revision = 'c031d46af369'
@ -20,8 +23,14 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', sa.Text(), nullable=True))
else:
with op.batch_alter_table('apps', schema=None) as batch_op:
batch_op.add_column(sa.Column('tracing', models.types.LongText(), nullable=True))
# ### end Alembic commands ###

View File

@ -7,10 +7,14 @@ Create Date: 2023-09-22 15:41:01.243183
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '2e9819ca5b28'
down_revision = 'ab23c11305d4'
@ -20,19 +24,35 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
else:
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=True))
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
batch_op.drop_column('dataset_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
else:
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
batch_op.add_column(sa.Column('dataset_id', models.types.StringUUID(), autoincrement=False, nullable=True))
batch_op.drop_index('api_token_tenant_idx')
batch_op.drop_column('tenant_id')
# ### end Alembic commands ###
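Note: models.types.StringUUID plays the same role for UUID columns — the PostgreSQL branches use postgresql.UUID() directly, while the fallback keeps StringUUID. Its actual definition is not shown in this diff; assuming the conventional TypeDecorator recipe, it is roughly:

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID

class StringUUID(sa.types.TypeDecorator):
    # Illustrative sketch only: native UUID on PostgreSQL, CHAR(36) elsewhere.
    impl = sa.CHAR(36)
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(sa.CHAR(36))

    def process_bind_param(self, value, dialect):
        return None if value is None else str(value)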

View File

@ -7,10 +7,14 @@ Create Date: 2024-03-07 08:30:29.133614
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '42e85ed5564d'
down_revision = 'f9107f83abab'
@ -20,31 +24,59 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
else:
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=True)
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=False)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
with op.batch_alter_table('conversations', schema=None) as batch_op:
batch_op.alter_column('model_id',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('model_provider',
existing_type=sa.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('app_model_config_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -6,10 +6,14 @@ Create Date: 2024-01-12 03:42:27.362415
"""
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '4829e54d2fee'
down_revision = '114eed84c228'
@ -19,21 +23,39 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=True)
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=postgresql.UUID(),
nullable=True)
else:
# MySQL: Use compatible syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=False)
conn = op.get_bind()
if _is_pg(conn):
# PostgreSQL: Keep original syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
# MySQL: Use compatible syntax
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.alter_column('message_chain_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -6,10 +6,14 @@ Create Date: 2024-03-14 04:54:56.679506
"""
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '563cf8bf777b'
down_revision = 'b5429b71023c'
@ -19,19 +23,35 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=True)
else:
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=postgresql.UUID(),
nullable=False)
else:
with op.batch_alter_table('tool_files', schema=None) as batch_op:
batch_op.alter_column('conversation_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@ -48,9 +48,12 @@ def upgrade():
with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
if _is_pg(conn):
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))
else:
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('collection_binding_id', models.types.StringUUID(), nullable=True))
# ### end Alembic commands ###

Some files were not shown because too many files have changed in this diff.