mirror of
https://github.com/langgenius/dify.git
synced 2026-05-24 02:47:53 +08:00
Compare commits
56 Commits
dependabot
...
fix/model-
| Author | SHA1 | Date | |
|---|---|---|---|
| 521545d52e | |||
| 72ee50c74f | |||
| 8d99326fb3 | |||
| 2a0c098857 | |||
| 790ca72627 | |||
| 4d8b6c7dc0 | |||
| 473c945839 | |||
| a698c60b29 | |||
| 24bab5fb2a | |||
| 93b7a81071 | |||
| 157e6244dd | |||
| 964aaad7ed | |||
| 92181dbe09 | |||
| 30deef45d9 | |||
| ee28074390 | |||
| 1fb491337b | |||
| 82b0a03f5a | |||
| 6185016910 | |||
| b4f5f4869f | |||
| 7ecbed3b04 | |||
| 5b58defd62 | |||
| 73196de5e1 | |||
| ea5e487d3c | |||
| f19702f76c | |||
| 092c8bca81 | |||
| c50d504c44 | |||
| 1b4356b66a | |||
| 7f633622aa | |||
| 66f5ab4cfc | |||
| 0cf9597f52 | |||
| 60cd346fa6 | |||
| 56d4d54c16 | |||
| 9f9cb4d17e | |||
| 7d0d9019d8 | |||
| d646bcf257 | |||
| e3b45a48eb | |||
| 848c15a265 | |||
| be8627233d | |||
| 1fe8b7fb1d | |||
| 5a585c8618 | |||
| cc9b90a5ae | |||
| b64d4b53ca | |||
| 5cdf4e405b | |||
| 7cb14cb4cc | |||
| de38bba99b | |||
| f04d809426 | |||
| 7ed3c7c500 | |||
| 77f1aeb1ac | |||
| 7bc5c89e3c | |||
| 718ab8433e | |||
| 8f197c5a0a | |||
| 0295862d0d | |||
| 2b2a5824c1 | |||
| 468cc19e68 | |||
| 77333e57a7 | |||
| f52491e2c1 |
44
.github/CODEOWNERS
vendored
44
.github/CODEOWNERS
vendored
@ -92,28 +92,28 @@
|
||||
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @WH-2099
|
||||
|
||||
# Backend - Trigger/Schedule/Webhook
|
||||
/api/controllers/trigger/ @Mairuis
|
||||
/api/controllers/console/app/workflow_trigger.py @Mairuis
|
||||
/api/controllers/console/workspace/trigger_providers.py @Mairuis
|
||||
/api/core/trigger/ @Mairuis
|
||||
/api/core/app/layers/trigger_post_layer.py @Mairuis
|
||||
/api/services/trigger/ @Mairuis
|
||||
/api/models/trigger.py @Mairuis
|
||||
/api/fields/workflow_trigger_fields.py @Mairuis
|
||||
/api/repositories/workflow_trigger_log_repository.py @Mairuis
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis
|
||||
/api/libs/schedule_utils.py @Mairuis
|
||||
/api/services/workflow/scheduler.py @Mairuis
|
||||
/api/schedule/trigger_provider_refresh_task.py @Mairuis
|
||||
/api/schedule/workflow_schedule_task.py @Mairuis
|
||||
/api/tasks/trigger_processing_tasks.py @Mairuis
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis
|
||||
/api/tasks/workflow_schedule_tasks.py @Mairuis
|
||||
/api/tasks/workflow_cfs_scheduler/ @Mairuis
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis
|
||||
/api/controllers/trigger/ @CourTeous33
|
||||
/api/controllers/console/app/workflow_trigger.py @CourTeous33
|
||||
/api/controllers/console/workspace/trigger_providers.py @CourTeous33
|
||||
/api/core/trigger/ @CourTeous33
|
||||
/api/core/app/layers/trigger_post_layer.py @CourTeous33
|
||||
/api/services/trigger/ @CourTeous33
|
||||
/api/models/trigger.py @CourTeous33
|
||||
/api/fields/workflow_trigger_fields.py @CourTeous33
|
||||
/api/repositories/workflow_trigger_log_repository.py @CourTeous33
|
||||
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @CourTeous33
|
||||
/api/libs/schedule_utils.py @CourTeous33
|
||||
/api/services/workflow/scheduler.py @CourTeous33
|
||||
/api/schedule/trigger_provider_refresh_task.py @CourTeous33
|
||||
/api/schedule/workflow_schedule_task.py @CourTeous33
|
||||
/api/tasks/trigger_processing_tasks.py @CourTeous33
|
||||
/api/tasks/trigger_subscription_refresh_tasks.py @CourTeous33
|
||||
/api/tasks/workflow_schedule_tasks.py @CourTeous33
|
||||
/api/tasks/workflow_cfs_scheduler/ @CourTeous33
|
||||
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @CourTeous33
|
||||
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @CourTeous33
|
||||
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @CourTeous33
|
||||
/api/events/event_handlers/sync_webhook_when_app_created.py @CourTeous33
|
||||
|
||||
# Backend - Async Workflow
|
||||
/api/services/async_workflow_service.py @Mairuis
|
||||
|
||||
4
.github/actions/setup-web/action.yml
vendored
4
.github/actions/setup-web/action.yml
vendored
@ -5,11 +5,11 @@ runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5
|
||||
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
|
||||
with:
|
||||
run_install: false
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
|
||||
uses: voidzero-dev/setup-vp@ca1c46663915d6c1042ae23bd39ab85718bfb0fa # v1.10.0
|
||||
with:
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
|
||||
2
.github/workflows/api-tests.yml
vendored
2
.github/workflows/api-tests.yml
vendored
@ -195,7 +195,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
|
||||
12
.github/workflows/build-push.yml
vendored
12
.github/workflows/build-push.yml
vendored
@ -35,15 +35,15 @@ jobs:
|
||||
- service_name: "build-api-amd64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
- service_name: "build-web-amd64"
|
||||
@ -117,8 +117,8 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "validate-api-amd64"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
- service_name: "validate-web-amd64"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
|
||||
18
.github/workflows/docker-build.yml
vendored
18
.github/workflows/docker-build.yml
vendored
@ -6,6 +6,12 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- api/Dockerfile.dockerignore
|
||||
- api/pyproject.toml
|
||||
- api/uv.lock
|
||||
- dify-agent/pyproject.toml
|
||||
- dify-agent/README.md
|
||||
- dify-agent/src/**
|
||||
- web/Dockerfile
|
||||
|
||||
concurrency:
|
||||
@ -25,13 +31,13 @@ jobs:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
runs_on: depot-ubuntu-24.04-4
|
||||
@ -64,8 +70,8 @@ jobs:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
context: "{{defaultContext}}"
|
||||
file: "api/Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
|
||||
@ -63,8 +63,8 @@ jobs:
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
--base "$GITHUB_WORKSPACE/base_report.json" \
|
||||
< "$GITHUB_WORKSPACE/pr_report.json")"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -65,6 +65,9 @@ jobs:
|
||||
# Save structured data for the fork-PR comment workflow
|
||||
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||
cp /tmp/pyrefly_report_base.json base_report.json
|
||||
# Keep fork-PR comments correct while the trusted workflow_run job is
|
||||
# still using the default-branch renderer, which resolves --base from api/.
|
||||
cp /tmp/pyrefly_report_base.json api/base_report.json
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
@ -77,6 +80,7 @@ jobs:
|
||||
path: |
|
||||
pr_report.json
|
||||
base_report.json
|
||||
api/base_report.json
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
|
||||
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -47,6 +47,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Response Contract Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
|
||||
|
||||
- name: Run Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: make type-check-core
|
||||
|
||||
2
.github/workflows/translate-i18n-claude.yml
vendored
2
.github/workflows/translate-i18n-claude.yml
vendored
@ -158,7 +158,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119
|
||||
uses: anthropics/claude-code-action@1dc994ee7a008f0ecc866d9ac23ef036b7229f84 # v1.0.127
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
6
.github/workflows/web-tests.yml
vendored
6
.github/workflows/web-tests.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
run: vp test run --reporter=blob --reporter=minimal --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
@ -117,7 +117,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
|
||||
with:
|
||||
directory: packages/dify-ui/coverage
|
||||
flags: dify-ui
|
||||
|
||||
11
Makefile
11
Makefile
@ -75,13 +75,19 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
|
||||
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
|
||||
@uv run --project api --dev ruff format ./api
|
||||
@uv run --project api --dev ruff check --fix ./api
|
||||
@$(MAKE) api-contract-lint
|
||||
@uv run --directory api --dev lint-imports
|
||||
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
api-contract-lint:
|
||||
@echo "🔎 Linting Flask response contracts..."
|
||||
@uv run --project api --dev python api/dev/lint_response_contracts.py
|
||||
@echo "✅ Response contract lint complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@ -191,6 +197,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@ -204,4 +211,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all
|
||||
|
||||
@ -657,6 +657,7 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
|
||||
PLUGIN_REMOTE_INSTALL_HOST=localhost
|
||||
PLUGIN_MAX_PACKAGE_SIZE=15728640
|
||||
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
|
||||
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
|
||||
|
||||
# Marketplace configuration
|
||||
@ -767,6 +768,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while use redis as event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -195,6 +195,7 @@ Before opening a PR / submitting:
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
|
||||
@ -22,9 +22,11 @@ RUN apt-get update \
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
|
||||
COPY api/pyproject.toml api/uv.lock ./
|
||||
COPY api/providers ./providers
|
||||
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
|
||||
COPY dify-agent/src /app/dify-agent/src
|
||||
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# production stage
|
||||
@ -108,10 +110,10 @@ RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
|
||||
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
|
||||
|
||||
# Copy source code
|
||||
COPY --chown=dify:dify . /app/api/
|
||||
COPY --chown=dify:dify api /app/api/
|
||||
|
||||
# Prepare entrypoint script
|
||||
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
|
||||
COPY --chown=dify:dify --chmod=755 api/docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
|
||||
ARG COMMIT_SHA
|
||||
|
||||
25
api/Dockerfile.dockerignore
Normal file
25
api/Dockerfile.dockerignore
Normal file
@ -0,0 +1,25 @@
|
||||
*
|
||||
|
||||
!api/
|
||||
!api/**
|
||||
!dify-agent/
|
||||
!dify-agent/pyproject.toml
|
||||
!dify-agent/README.md
|
||||
!dify-agent/src/
|
||||
!dify-agent/src/**
|
||||
|
||||
api/.venv
|
||||
api/.venv/**
|
||||
api/.env
|
||||
api/*.env.*
|
||||
api/.idea
|
||||
api/.mypy_cache
|
||||
api/.ruff_cache
|
||||
api/storage/generate_files/*
|
||||
api/storage/privkeys/*
|
||||
api/storage/tools/*
|
||||
api/storage/upload_files/*
|
||||
api/logs
|
||||
api/*.log*
|
||||
**/__pycache__
|
||||
**/*.pyc
|
||||
@ -49,6 +49,7 @@ class AgentBackendModelConfig(BaseModel):
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@ -138,6 +139,7 @@ class AgentBackendRunRequestBuilder:
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -11,6 +11,7 @@ from configs import dify_config
|
||||
from core.helper import encrypter
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from core.tools.utils.system_encryption import encrypt_system_params
|
||||
from extensions.ext_database import db
|
||||
from models import Tenant
|
||||
@ -20,7 +21,6 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
|
||||
from models.tools import ToolOAuthSystemClient
|
||||
from services.plugin.data_migration import PluginDataMigration
|
||||
from services.plugin.plugin_migration import PluginMigration
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
|
||||
@ -25,6 +25,7 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
|
||||
def __init__(self, settings_cls: type[BaseSettings]):
|
||||
super().__init__(settings_cls)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -90,6 +91,7 @@ class DifyConfig(
|
||||
# Thanks for your concentration and consideration.
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from configs.extra.agent_backend_config import AgentBackendConfig
|
||||
from configs.extra.archive_config import ArchiveStorageConfig
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
@ -5,6 +6,7 @@ from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
class ExtraServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
AgentBackendConfig,
|
||||
ArchiveStorageConfig,
|
||||
NotionConfig,
|
||||
SentryConfig,
|
||||
|
||||
23
api/configs/extra/agent_backend_config.py
Normal file
23
api/configs/extra/agent_backend_config.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AgentBackendConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for the Agent backend runtime integration.
|
||||
"""
|
||||
|
||||
AGENT_BACKEND_BASE_URL: str | None = Field(
|
||||
description="Base URL for the Dify Agent backend service.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
AGENT_BACKEND_USE_FAKE: bool = Field(
|
||||
description="Use the deterministic in-process fake Agent backend client.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_BACKEND_FAKE_SCENARIO: str = Field(
|
||||
description="Scenario used by the fake Agent backend client.",
|
||||
default="success",
|
||||
)
|
||||
@ -265,6 +265,11 @@ class PluginConfig(BaseSettings):
|
||||
default=60 * 60,
|
||||
)
|
||||
|
||||
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
|
||||
description="TTL in seconds for caching tenant plugin model providers in Redis",
|
||||
default=60 * 60 * 24,
|
||||
)
|
||||
|
||||
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
|
||||
description="Maximum allowed size (bytes) for plugin-generated files",
|
||||
default=50 * 1024 * 1024,
|
||||
|
||||
@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic.fields import FieldInfo
|
||||
@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource):
|
||||
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
if not isinstance(self.remote_configs, dict):
|
||||
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource):
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to parse config: {e}")
|
||||
|
||||
@override
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
field_value = self.remote_configs.get(field_name)
|
||||
if field_value is None:
|
||||
|
||||
@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
from typing import Any, Protocol, final, override, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -133,10 +133,12 @@ class NullAppContext(AppContext):
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
@ -146,6 +148,7 @@ class NullAppContext(AppContext):
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
@ -6,7 +6,7 @@ import contextvars
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final
|
||||
from typing import Any, final, override
|
||||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
@ -30,15 +30,18 @@ class FlaskAppContext(AppContext):
|
||||
"""
|
||||
self._flask_app = flask_app
|
||||
|
||||
@override
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value from Flask app config."""
|
||||
return self._flask_app.config.get(key, default)
|
||||
|
||||
@override
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name."""
|
||||
return self._flask_app.extensions.get(name)
|
||||
|
||||
@contextmanager
|
||||
@override
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask app context."""
|
||||
with self._flask_app.app_context():
|
||||
|
||||
@ -36,6 +36,24 @@ class FileInfo(BaseModel):
|
||||
size: int
|
||||
|
||||
|
||||
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
if isinstance(query_string, bytes):
|
||||
raw_query = query_string.decode()
|
||||
else:
|
||||
raw_query = query_string
|
||||
if not raw_query:
|
||||
return decoded_url
|
||||
|
||||
if decoded_url.endswith(("?", "&")):
|
||||
separator = ""
|
||||
elif urllib.parse.urlsplit(decoded_url).query:
|
||||
separator = "&"
|
||||
else:
|
||||
separator = "?"
|
||||
return f"{decoded_url}{separator}{raw_query}"
|
||||
|
||||
|
||||
def guess_file_info_from_response(response: httpx.Response):
|
||||
url = str(response.url)
|
||||
# Try to extract filename from URL
|
||||
|
||||
@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
|
||||
@ -269,12 +269,12 @@ class AnnotationApi(Resource):
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return result, 204
|
||||
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return "", 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(str(app_id))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
|
||||
@ -633,7 +633,7 @@ class AppApi(Resource):
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
|
||||
@ -29,9 +29,6 @@ from fields.conversation_fields import (
|
||||
from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ResultResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
@ -77,7 +74,6 @@ register_schema_models(
|
||||
ConversationMessageDetailResponse,
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
)
|
||||
@ -194,7 +190,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
@ -347,7 +343,7 @@ class ChatConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
|
||||
@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
|
||||
@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
from typing import Any, cast
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -30,26 +31,10 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
dataset_detail_fields,
|
||||
dataset_fields,
|
||||
dataset_query_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
file_info_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
@ -61,58 +46,6 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
tag_model = get_or_create_model("Tag", tag_fields)
|
||||
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
|
||||
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
|
||||
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
|
||||
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
|
||||
|
||||
content_fields_copy = content_fields.copy()
|
||||
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
|
||||
content_model = get_or_create_model("DatasetContent", content_fields_copy)
|
||||
|
||||
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
|
||||
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
|
||||
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
|
||||
|
||||
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
|
||||
related_app_list_copy = related_app_list.copy()
|
||||
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
|
||||
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
@ -208,9 +141,165 @@ class ConsoleDatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetListItemResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str]
|
||||
|
||||
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetListItemResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
class DatasetQueryFileInfoResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class DatasetQueryContentResponse(ResponseModel):
|
||||
content_type: str
|
||||
content: str
|
||||
file_info: DatasetQueryFileInfoResponse | None = None
|
||||
|
||||
|
||||
class DatasetQueryDetailResponse(ResponseModel):
|
||||
id: str
|
||||
queries: list[DatasetQueryContentResponse]
|
||||
source: str
|
||||
source_app_id: str | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: int
|
||||
|
||||
@field_validator("created_at", mode="before")
|
||||
@classmethod
|
||||
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DatasetQueryListResponse(ResponseModel):
|
||||
data: list[DatasetQueryDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class RelatedAppResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
mode: str = Field(validation_alias="mode_compatible_with_agent")
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
icon_url: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_icon_url(self) -> "RelatedAppResponse":
|
||||
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
|
||||
return self
|
||||
|
||||
|
||||
class RelatedAppListResponse(ResponseModel):
|
||||
data: list[RelatedAppResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class DocumentStatusResponse(ResponseModel):
|
||||
id: str
|
||||
indexing_status: str
|
||||
processing_started_at: int | None
|
||||
parsing_completed_at: int | None
|
||||
cleaning_completed_at: int | None
|
||||
splitting_completed_at: int | None
|
||||
completed_at: int | None
|
||||
paused_at: int | None
|
||||
error: str | None
|
||||
stopped_at: int | None
|
||||
completed_segments: int | None = None
|
||||
total_segments: int | None = None
|
||||
|
||||
@field_validator(
|
||||
"processing_started_at",
|
||||
"parsing_completed_at",
|
||||
"cleaning_completed_at",
|
||||
"splitting_completed_at",
|
||||
"completed_at",
|
||||
"paused_at",
|
||||
"stopped_at",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
|
||||
return to_timestamp(value)
|
||||
|
||||
|
||||
class DocumentStatusListResponse(ResponseModel):
|
||||
data: list[DocumentStatusResponse]
|
||||
|
||||
|
||||
class ErrorDocsResponse(DocumentStatusListResponse):
|
||||
total: int
|
||||
|
||||
|
||||
class IndexingEstimatePreviewItemResponse(ResponseModel):
|
||||
content: str
|
||||
child_chunks: list[str] | None = None
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
|
||||
question: str
|
||||
answer: str
|
||||
|
||||
|
||||
class IndexingEstimateResponse(ResponseModel):
|
||||
total_segments: int
|
||||
preview: list[IndexingEstimatePreviewItemResponse]
|
||||
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
|
||||
|
||||
|
||||
class RetrievalSettingResponse(ResponseModel):
|
||||
retrieval_method: list[str]
|
||||
|
||||
|
||||
class PartialMemberListResponse(ResponseModel):
|
||||
data: list[str]
|
||||
|
||||
|
||||
class AutoDisableLogsResponse(ResponseModel):
|
||||
document_ids: list[str]
|
||||
count: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetQueryListResponse,
|
||||
IndexingEstimateResponse,
|
||||
RelatedAppListResponse,
|
||||
DocumentStatusListResponse,
|
||||
ErrorDocsResponse,
|
||||
RetrievalSettingResponse,
|
||||
PartialMemberListResponse,
|
||||
AutoDisableLogsResponse,
|
||||
)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
@ -293,17 +382,8 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
class DatasetListApi(Resource):
|
||||
@console_ns.doc("get_datasets")
|
||||
@console_ns.doc(description="Get list of datasets")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"ids": "Filter by dataset IDs (list)",
|
||||
"keyword": "Search keyword",
|
||||
"tag_ids": "Filter by tag IDs (list)",
|
||||
"include_all": "Include all datasets (default: false)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Datasets retrieved successfully")
|
||||
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
|
||||
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -342,7 +422,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
|
||||
partial_members_map: dict[str, list[str]] = {}
|
||||
if dataset_ids:
|
||||
@ -379,12 +459,12 @@ class DatasetListApi(Resource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -413,7 +493,7 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -421,7 +501,11 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("get_dataset")
|
||||
@console_ns.doc(description="Get dataset details")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -437,7 +521,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@ -470,7 +554,11 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -506,7 +594,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
@ -535,7 +623,7 @@ class DatasetApi(Resource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@ -567,7 +655,11 @@ class DatasetQueryApi(Resource):
|
||||
@console_ns.doc("get_dataset_queries")
|
||||
@console_ns.doc(description="Get dataset query history")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Query history retrieved successfully",
|
||||
console_ns.models[DatasetQueryListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -589,20 +681,24 @@ class DatasetQueryApi(Resource):
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
|
||||
response = {
|
||||
"data": marshal(dataset_queries, dataset_query_detail_model),
|
||||
"data": dataset_queries,
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetQueryListResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/indexing-estimate")
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
@console_ns.doc("estimate_dataset_indexing")
|
||||
@console_ns.doc(description="Estimate dataset indexing cost")
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing estimate calculated successfully",
|
||||
console_ns.models[IndexingEstimateResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -699,11 +795,14 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@console_ns.doc("get_dataset_related_apps")
|
||||
@console_ns.doc(description="Get applications related to dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Related apps retrieved successfully",
|
||||
console_ns.models[RelatedAppListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(related_app_list_model)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -724,7 +823,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
if app_model:
|
||||
related_apps.append(app_model)
|
||||
|
||||
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
|
||||
@ -732,7 +831,11 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@console_ns.doc("get_dataset_indexing_status")
|
||||
@console_ns.doc(description="Get dataset indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing status retrieved successfully",
|
||||
console_ns.models[DocumentStatusListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -778,9 +881,8 @@ class DatasetIndexingStatusApi(Resource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data, 200
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys")
|
||||
@ -873,7 +975,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
|
||||
@ -907,13 +1009,18 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting")
|
||||
@console_ns.doc(description="Get dataset retrieval settings")
|
||||
@console_ns.response(200, "Retrieval settings retrieved successfully")
|
||||
@console_ns.response(
|
||||
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
@ -921,12 +1028,19 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting_mock")
|
||||
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@console_ns.doc(params={"vector_type": "Vector store type"})
|
||||
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Mock retrieval settings retrieved successfully",
|
||||
console_ns.models[RetrievalSettingResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
@ -934,7 +1048,7 @@ class DatasetErrorDocs(Resource):
|
||||
@console_ns.doc("get_dataset_error_docs")
|
||||
@console_ns.doc(description="Get dataset error documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Error documents retrieved successfully")
|
||||
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -946,7 +1060,7 @@ class DatasetErrorDocs(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
@ -954,7 +1068,11 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@console_ns.doc("get_dataset_permission_users")
|
||||
@console_ns.doc(description="Get dataset permission user list")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Permission users retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Permission users retrieved successfully",
|
||||
console_ns.models[PartialMemberListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -973,9 +1091,7 @@ class DatasetPermissionUserListApi(Resource):
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return {
|
||||
"data": partial_members_list,
|
||||
}, 200
|
||||
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
|
||||
@ -983,7 +1099,11 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
@console_ns.doc("get_dataset_auto_disable_logs")
|
||||
@console_ns.doc(description="Get dataset auto disable logs")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Auto disable logs retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Auto disable logs retrieved successfully",
|
||||
console_ns.models[AutoDisableLogsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -993,4 +1113,4 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
|
||||
|
||||
@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
@ -966,7 +966,7 @@ class DocumentApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot pause completed document.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Document is not in paused status.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
|
||||
@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segments(segment_ids, document, dataset)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
SegmentService.delete_child_chunk(child_chunk, dataset)
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -22,7 +26,12 @@ from services.metadata_service import MetadataService
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -31,7 +40,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -44,18 +53,22 @@ class DatasetMetadataCreateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return metadata, 201
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -64,7 +77,7 @@ class DatasetMetadataApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -79,7 +92,7 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -96,7 +109,8 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return {"result": "success"}, 204
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
@ -105,9 +119,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -116,7 +135,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -130,7 +149,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -140,7 +160,10 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -153,4 +176,5 @@ class DocumentMetadataEditApi(Resource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
# Frontend callers only await success and invalidate caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
||||
@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
db.session.delete(installed_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
|
||||
@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource):
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -28,7 +28,32 @@ class FeatureApi(Resource):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
).model_dump()
|
||||
payload.pop("vector_space", None)
|
||||
return payload
|
||||
|
||||
|
||||
@console_ns.route("/features/vector-space")
|
||||
class FeatureVectorSpaceApi(Resource):
|
||||
@console_ns.doc("get_tenant_feature_vector_space")
|
||||
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Success",
|
||||
console_ns.models[LimitationModel.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_utm_record
|
||||
def get(self):
|
||||
"""Get vector-space usage and limit for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_vector_space(current_tenant_id).model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -34,7 +33,7 @@ class GetRemoteFileInfo(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
|
||||
@ -56,6 +56,12 @@ from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import (
|
||||
ChangeEmailNewEmailToken,
|
||||
ChangeEmailNewEmailVerifiedToken,
|
||||
ChangeEmailOldEmailToken,
|
||||
ChangeEmailOldEmailVerifiedToken,
|
||||
)
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
@ -620,8 +626,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
account = None
|
||||
user_email = None
|
||||
account = current_user
|
||||
user_email = current_user.email
|
||||
email_for_sending = args.email.lower()
|
||||
# Default to the initial phase; any legacy/unexpected client input is
|
||||
# coerced back to `old_email` so we never trust the caller to declare
|
||||
@ -636,24 +642,18 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# The token used to request a new-email code must come from the
|
||||
# old-email verification step. This prevents the bypass described
|
||||
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
if not isinstance(reset_data, ChangeEmailOldEmailVerifiedToken):
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
if not reset_data.is_bound_to_account(current_user.id):
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.email
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
email_for_sending = account.email
|
||||
user_email = account.email
|
||||
if email_for_sending != current_user.email.lower():
|
||||
raise InvalidEmailError()
|
||||
email_for_sending = current_user.email
|
||||
|
||||
token = AccountService.send_change_email_email(
|
||||
account=account,
|
||||
@ -674,6 +674,7 @@ class ChangeEmailCheckApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
@ -686,42 +687,26 @@ class ChangeEmailCheckApi(Resource):
|
||||
token_data = AccountService.get_change_email_data(args.token)
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
if not token_data.is_bound_to_account(current_user.id):
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
normalized_token_email = token_data.email.lower()
|
||||
if user_email != normalized_token_email:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args.code != token_data.get("code"):
|
||||
if args.code != token_data.code:
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
# Only advance tokens that were minted by the matching send-code step;
|
||||
# refuse tokens that have already progressed or lack a phase marker so
|
||||
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
|
||||
# is strictly enforced.
|
||||
phase_transitions = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
if isinstance(token_data, ChangeEmailOldEmailToken | ChangeEmailNewEmailToken):
|
||||
refreshed_token_data = token_data.promote()
|
||||
else:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token that carries the
|
||||
# upgraded phase so later steps can check it.
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
new_token = AccountService.generate_change_email_token(refreshed_token_data, current_user)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
|
||||
@ -746,27 +731,22 @@ class ChangeEmailResetApi(Resource):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
if not reset_data.is_bound_to_account(current_user.id):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Only tokens that completed both verification phases may be used to
|
||||
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
|
||||
# the initial send-code step could be replayed directly here.
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
if not isinstance(reset_data, ChangeEmailNewEmailVerifiedToken):
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Bind the new email to the token that was mailed and verified, so a
|
||||
# verified token cannot be reused with a different `new_email` value.
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
if reset_data.email.lower() != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
if current_user.email.lower() != reset_data.old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
# Revoke only after all checks pass so failed attempts don't burn a
|
||||
|
||||
@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
|
||||
@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console import console_ns
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.plugin_service import PluginService
|
||||
from fields.base import ResponseModel
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
@ -22,7 +23,6 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_parameter_service import PluginParameterService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
|
||||
class ParserList(BaseModel):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Literal, override
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -76,11 +76,13 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return _enum_value(obj.status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_enum_models,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@ -17,9 +21,10 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -119,6 +124,21 @@ class TagUnbindingPayload(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class KnowledgeTagResponse(ResponseModel):
|
||||
model_config = ConfigDict(coerce_numbers_to_str=True)
|
||||
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
# TODO: The public Service API docs expose binding_count as string|null.
|
||||
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
|
||||
binding_count: str | None = None
|
||||
|
||||
|
||||
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
|
||||
pass
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
@ -127,6 +147,29 @@ class DatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
|
||||
partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
# todo: duplicate code, but the partial_member_list has different nullability
|
||||
class DatasetListResponse(ResponseModel):
|
||||
data: list[DatasetDetailResponse]
|
||||
has_more: bool
|
||||
limit: int
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class DatasetBoundTagResponse(ResponseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class DatasetBoundTagListResponse(ResponseModel):
|
||||
data: list[DatasetBoundTagResponse]
|
||||
total: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
@ -137,9 +180,17 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
SimpleResultResponse,
|
||||
KnowledgeTagResponse,
|
||||
KnowledgeTagListResponse,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetBoundTagListResponse,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets")
|
||||
@ -154,9 +205,18 @@ class DatasetListApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Datasets retrieved successfully",
|
||||
service_api_ns.models[DatasetListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict())
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
if "tag_ids" in request.args:
|
||||
query_params["tag_ids"] = request.args.getlist("tag_ids")
|
||||
query = DatasetListQuery.model_validate(query_params)
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
@ -175,22 +235,17 @@ class DatasetListApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
for item in data:
|
||||
if (
|
||||
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
|
||||
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
|
||||
):
|
||||
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
|
||||
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
|
||||
)
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True # type: ignore
|
||||
item["embedding_available"] = True
|
||||
else:
|
||||
item["embedding_available"] = False # type: ignore
|
||||
item["embedding_available"] = False
|
||||
else:
|
||||
item["embedding_available"] = True # type: ignore
|
||||
item["embedding_available"] = True
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
@ -198,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@ -210,6 +265,11 @@ class DatasetListApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset created successfully",
|
||||
service_api_ns.models[DatasetDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
@ -253,7 +313,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
return dump_response(DatasetDetailResponse, dataset), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -271,6 +331,11 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -280,7 +345,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
# check embedding setting
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -312,7 +377,13 @@ class DatasetApi(DatasetApiResource):
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
return data, 200
|
||||
return (
|
||||
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
|
||||
mode="json",
|
||||
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@ -326,6 +397,11 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -376,7 +452,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -389,7 +465,7 @@ class DatasetApi(DatasetApiResource):
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@ -502,7 +578,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
|
||||
@ -515,14 +591,18 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[KnowledgeTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -534,6 +614,11 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag created successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -543,9 +628,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -558,6 +644,11 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag updated successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -569,9 +660,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@ -656,6 +748,11 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[DatasetBoundTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
@ -663,5 +760,4 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -27,7 +31,13 @@ register_schema_models(
|
||||
DocumentMetadataOperation,
|
||||
MetadataOperationData,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -43,6 +53,9 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
@ -55,7 +68,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return marshal(metadata, dataset_metadata_fields), 201
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@service_api_ns.doc("get_dataset_metadata")
|
||||
@service_api_ns.doc(description="Get all metadata for a dataset")
|
||||
@ -67,13 +80,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all metadata for a dataset."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -89,6 +106,9 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
@ -102,7 +122,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
@service_api_ns.doc(description="Delete metadata")
|
||||
@ -114,6 +134,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(204, "Metadata deleted successfully")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Delete metadata."""
|
||||
@ -138,10 +159,15 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all built-in metadata fields."""
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -157,9 +183,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Action completed successfully",
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
|
||||
@ -175,7 +199,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -194,7 +218,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Documents metadata updated successfully",
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
service_api_ns.models[DatasetMetadataActionResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
@ -209,4 +233,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
|
||||
@service_api_ns.route("/")
|
||||
class IndexApi(Resource):
|
||||
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
|
||||
def get(self):
|
||||
def get(self) -> dict[str, str]:
|
||||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
|
||||
@ -136,7 +136,7 @@ class ConversationApi(WebApiResource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
@ -59,7 +58,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
Raises:
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
|
||||
@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource):
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.file import file_manager
|
||||
@ -66,6 +67,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -51,6 +52,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
|
||||
return historic_prompt
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
|
||||
@ -55,6 +55,7 @@ from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
@ -145,9 +146,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
conversation = None
|
||||
else:
|
||||
raise
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from graphon.enums import NodeType
|
||||
|
||||
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
@override
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
|
||||
return None
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -791,10 +791,25 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: dict[str, Any] | None = None,
|
||||
_visited_folder_ids: set[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
|
||||
Recursively lists all files inside the given folder prefix.
|
||||
``_visited_folder_ids`` tracks folders already expanded so that a
|
||||
self-referencing folder (where the API returns the folder as its own
|
||||
child) cannot cause infinite recursion.
|
||||
"""
|
||||
if _visited_folder_ids is None:
|
||||
_visited_folder_ids = set()
|
||||
|
||||
# Guard: skip folders we have already expanded to prevent infinite
|
||||
# recursion from self-referencing folder entries in the API response.
|
||||
if prefix in _visited_folder_ids:
|
||||
return
|
||||
_visited_folder_ids.add(prefix)
|
||||
|
||||
result_generator = datasource_runtime.online_drive_browse_files(
|
||||
user_id=user_id,
|
||||
request=OnlineDriveBrowseFilesRequest(
|
||||
@ -806,10 +821,14 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
provider_type=datasource_runtime.datasource_provider_type(),
|
||||
)
|
||||
is_truncated = False
|
||||
has_files = False
|
||||
for result in result_generator:
|
||||
for files in result.result:
|
||||
for file in files.files:
|
||||
has_files = True
|
||||
if file.type == "folder":
|
||||
if file.id in _visited_folder_ids:
|
||||
continue
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime,
|
||||
file.id,
|
||||
@ -818,6 +837,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
all_files,
|
||||
datasource_info,
|
||||
None,
|
||||
_visited_folder_ids,
|
||||
)
|
||||
else:
|
||||
all_files.append(
|
||||
@ -830,7 +850,17 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
is_truncated = files.is_truncated
|
||||
next_page_parameters = files.next_page_parameters
|
||||
|
||||
if is_truncated:
|
||||
# Guard: only follow pagination when the API actually returned files.
|
||||
# An empty folder that incorrectly reports ``is_truncated=True`` would
|
||||
# otherwise recurse forever on the same empty page.
|
||||
if is_truncated and has_files:
|
||||
self._get_files_in_folder(
|
||||
datasource_runtime, prefix, bucket, user_id, all_files, datasource_info, next_page_parameters
|
||||
datasource_runtime,
|
||||
prefix,
|
||||
bucket,
|
||||
user_id,
|
||||
all_files,
|
||||
datasource_info,
|
||||
next_page_parameters,
|
||||
_visited_folder_ids,
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -73,6 +76,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -31,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
) -> None:
|
||||
self._scope_getter = scope_getter
|
||||
|
||||
@override
|
||||
def current_scope(self) -> FileAccessScope | None:
|
||||
return self._scope_getter()
|
||||
|
||||
@override
|
||||
def apply_upload_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[UploadFile]],
|
||||
@ -62,6 +65,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[ToolFile]],
|
||||
@ -78,6 +82,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
|
||||
|
||||
@override
|
||||
def get_upload_file(
|
||||
self,
|
||||
*,
|
||||
@ -95,6 +100,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_tool_file(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -8,6 +8,7 @@ scope updates that matter to chat applications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
@ -23,9 +24,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunVariableUpdatedEvent):
|
||||
return
|
||||
@ -44,5 +47,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal, Self, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
|
||||
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
|
||||
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._paused = True
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
self._paused = False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, override
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
self._file_access_controller = file_access_controller
|
||||
|
||||
@property
|
||||
@override
|
||||
def multimodal_send_format(self) -> str:
|
||||
return dify_config.MULTIMODAL_SEND_FORMAT
|
||||
|
||||
@override
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
|
||||
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
|
||||
|
||||
@override
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
|
||||
return storage.load(path, stream=stream)
|
||||
|
||||
@override
|
||||
def load_file_bytes(self, *, file: File) -> bytes:
|
||||
storage_key = self._resolve_storage_key(file=file)
|
||||
data = storage.load(storage_key, stream=False)
|
||||
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
raise ValueError(f"file {storage_key} is not a bytes object")
|
||||
return data
|
||||
|
||||
@override
|
||||
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return file.remote_url
|
||||
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def resolve_upload_file_url(
|
||||
self,
|
||||
*,
|
||||
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
query["as_attachment"] = "true"
|
||||
return f"{url}?{urllib.parse.urlencode(query)}"
|
||||
|
||||
@override
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
|
||||
self._assert_tool_file_access(tool_file_id=tool_file_id)
|
||||
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
|
||||
|
||||
@override
|
||||
def verify_preview_signature(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -12,7 +12,7 @@ state.
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
@ -98,12 +98,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
self._workflow_execution = None
|
||||
self._node_execution_cache.clear()
|
||||
self._node_snapshots.clear()
|
||||
self._node_sequence = 0
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
match event:
|
||||
case GraphRunStartedEvent():
|
||||
@ -131,6 +133,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
case NodeRunPauseRequestedEvent():
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def get_icon_url(self, tenant_id: str) -> str:
|
||||
return self.icon
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,5 +68,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
@override
|
||||
def __iter__(self):
|
||||
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
||||
yield from self.configurations.items()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
@override
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class JavascriptCodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.JAVASCRIPT
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
_template_b64_placeholder: str = "{{template_b64}}"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def transform_response(cls, response: str):
|
||||
"""
|
||||
Transform response to dict
|
||||
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return {"result": cls.extract_result_str_from_response(response)}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Override base class to use base64 encoding for template code.
|
||||
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
import jinja2
|
||||
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return runner_script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_preload_script(cls) -> str:
|
||||
preload_script = dedent("""
|
||||
import jinja2
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class Python3CodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
|
||||
return CodeLanguage.PYTHON3
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
|
||||
@ -43,13 +43,16 @@ request_error = httpx.RequestError
|
||||
max_retries_exceeded_error = MaxRetriesExceededError
|
||||
|
||||
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]:
|
||||
"""Build per-scheme proxy transports with the same TLS policy as the SSRF client."""
|
||||
return {
|
||||
"http://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTP_URL,
|
||||
verify=verify,
|
||||
),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
|
||||
verify=verify,
|
||||
),
|
||||
}
|
||||
|
||||
@ -64,7 +67,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client:
|
||||
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
return httpx.Client(
|
||||
mounts=_create_proxy_mounts(),
|
||||
mounts=_create_proxy_mounts(verify=verify),
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
import flask
|
||||
|
||||
@ -15,6 +16,7 @@ class TraceContextFilter(logging.Filter):
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
@ -54,6 +56,7 @@ class IdentityContextFilter(logging.Filter):
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
from typing import Any, NotRequired, TypedDict, override
|
||||
|
||||
import orjson
|
||||
|
||||
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
@override
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
|
||||
@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
@override
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
@override
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
@override
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@ -159,6 +159,7 @@ class ClientSession(
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@override
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
@ -326,6 +327,7 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
@ -351,6 +353,7 @@ class ClientSession(
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
@override
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@ -358,6 +361,7 @@ class ClientSession(
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
@override
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
|
||||
@ -235,10 +235,11 @@ class TokenBufferMemory:
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
match content:
|
||||
case TextPromptMessageContent():
|
||||
inner_msg += f"{content.data}\n"
|
||||
case ImagePromptMessageContent():
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user