Compare commits

..

1 Commits

Author SHA1 Message Date
11c652f146 chore(deps): bump idna from 3.11 to 3.15 in /api
Bumps [idna](https://github.com/kjd/idna) from 3.11 to 3.15.
- [Release notes](https://github.com/kjd/idna/releases)
- [Changelog](https://github.com/kjd/idna/blob/master/HISTORY.md)
- [Commits](https://github.com/kjd/idna/compare/v3.11...v3.15)

---
updated-dependencies:
- dependency-name: idna
  dependency-version: '3.15'
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-05-19 21:44:21 +00:00
1630 changed files with 17442 additions and 31996 deletions

44
.github/CODEOWNERS vendored
View File

@ -92,28 +92,28 @@
/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @WH-2099
# Backend - Trigger/Schedule/Webhook
/api/controllers/trigger/ @CourTeous33
/api/controllers/console/app/workflow_trigger.py @CourTeous33
/api/controllers/console/workspace/trigger_providers.py @CourTeous33
/api/core/trigger/ @CourTeous33
/api/core/app/layers/trigger_post_layer.py @CourTeous33
/api/services/trigger/ @CourTeous33
/api/models/trigger.py @CourTeous33
/api/fields/workflow_trigger_fields.py @CourTeous33
/api/repositories/workflow_trigger_log_repository.py @CourTeous33
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @CourTeous33
/api/libs/schedule_utils.py @CourTeous33
/api/services/workflow/scheduler.py @CourTeous33
/api/schedule/trigger_provider_refresh_task.py @CourTeous33
/api/schedule/workflow_schedule_task.py @CourTeous33
/api/tasks/trigger_processing_tasks.py @CourTeous33
/api/tasks/trigger_subscription_refresh_tasks.py @CourTeous33
/api/tasks/workflow_schedule_tasks.py @CourTeous33
/api/tasks/workflow_cfs_scheduler/ @CourTeous33
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @CourTeous33
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @CourTeous33
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @CourTeous33
/api/events/event_handlers/sync_webhook_when_app_created.py @CourTeous33
/api/controllers/trigger/ @Mairuis
/api/controllers/console/app/workflow_trigger.py @Mairuis
/api/controllers/console/workspace/trigger_providers.py @Mairuis
/api/core/trigger/ @Mairuis
/api/core/app/layers/trigger_post_layer.py @Mairuis
/api/services/trigger/ @Mairuis
/api/models/trigger.py @Mairuis
/api/fields/workflow_trigger_fields.py @Mairuis
/api/repositories/workflow_trigger_log_repository.py @Mairuis
/api/repositories/sqlalchemy_workflow_trigger_log_repository.py @Mairuis
/api/libs/schedule_utils.py @Mairuis
/api/services/workflow/scheduler.py @Mairuis
/api/schedule/trigger_provider_refresh_task.py @Mairuis
/api/schedule/workflow_schedule_task.py @Mairuis
/api/tasks/trigger_processing_tasks.py @Mairuis
/api/tasks/trigger_subscription_refresh_tasks.py @Mairuis
/api/tasks/workflow_schedule_tasks.py @Mairuis
/api/tasks/workflow_cfs_scheduler/ @Mairuis
/api/events/event_handlers/sync_plugin_trigger_when_app_created.py @Mairuis
/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @Mairuis
/api/events/event_handlers/sync_workflow_schedule_when_app_published.py @Mairuis
/api/events/event_handlers/sync_webhook_when_app_created.py @Mairuis
# Backend - Async Workflow
/api/services/async_workflow_service.py @Mairuis

View File

@ -5,11 +5,11 @@ runs:
using: composite
steps:
- name: Setup pnpm
uses: pnpm/action-setup@0e279bb959325dab635dd2c09392533439d90093 # v6.0.8
uses: pnpm/action-setup@8912a9102ac27614460f54aedde9e1e7f9aec20d # v6.0.5
with:
run_install: false
- name: Setup Vite+
uses: voidzero-dev/setup-vp@ca1c46663915d6c1042ae23bd39ab85718bfb0fa # v1.10.0
uses: voidzero-dev/setup-vp@4f5aa3e38c781f1b01e78fb9255527cee8a6efa6 # v1.8.0
with:
node-version-file: .nvmrc
cache: true

View File

@ -195,7 +195,7 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
files: ./coverage.xml
disable_search: true

View File

@ -35,15 +35,15 @@ jobs:
- service_name: "build-api-amd64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}"
file: "api/Dockerfile"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}"
file: "api/Dockerfile"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
- service_name: "build-web-amd64"
@ -117,8 +117,8 @@ jobs:
matrix:
include:
- service_name: "validate-api-amd64"
build_context: "{{defaultContext}}"
file: "api/Dockerfile"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "validate-web-amd64"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"

View File

@ -6,12 +6,6 @@ on:
- "main"
paths:
- api/Dockerfile
- api/Dockerfile.dockerignore
- api/pyproject.toml
- api/uv.lock
- dify-agent/pyproject.toml
- dify-agent/README.md
- dify-agent/src/**
- web/Dockerfile
concurrency:
@ -31,13 +25,13 @@ jobs:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "api/Dockerfile"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: depot-ubuntu-24.04-4
context: "{{defaultContext}}"
file: "api/Dockerfile"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: depot-ubuntu-24.04-4
@ -70,8 +64,8 @@ jobs:
matrix:
include:
- service_name: "api-amd64"
context: "{{defaultContext}}"
file: "api/Dockerfile"
context: "{{defaultContext}}:api"
file: "Dockerfile"
- service_name: "web-amd64"
context: "{{defaultContext}}"
file: "web/Dockerfile"

View File

@ -63,8 +63,8 @@ jobs:
id: render
run: |
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base "$GITHUB_WORKSPACE/base_report.json" \
< "$GITHUB_WORKSPACE/pr_report.json")"
--base base_report.json \
< pr_report.json)"
{
echo "### Pyrefly Type Coverage"

View File

@ -65,9 +65,6 @@ jobs:
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
# Keep fork-PR comments correct while the trusted workflow_run job is
# still using the default-branch renderer, which resolves --base from api/.
cp /tmp/pyrefly_report_base.json api/base_report.json
- name: Save PR number
run: |
@ -80,7 +77,6 @@ jobs:
path: |
pr_report.json
base_report.json
api/base_report.json
pr_number.txt
- name: Comment PR with type coverage

View File

@ -47,10 +47,6 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Response Contract Linter
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
- name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: make type-check-core

View File

@ -158,7 +158,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.context.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@1dc994ee7a008f0ecc866d9ac23ef036b7229f84 # v1.0.127
uses: anthropics/claude-code-action@476e359e6203e73dad705c8b322e333fabbd7416 # v1.0.119
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -39,7 +39,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Run tests
run: vp test run --reporter=blob --reporter=minimal --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
@ -83,7 +83,7 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: web/coverage
flags: web
@ -117,7 +117,7 @@ jobs:
- name: Report coverage
if: ${{ env.CODECOV_TOKEN != '' }}
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 # v6.0.1
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
with:
directory: packages/dify-ui/coverage
flags: dify-ui

2
.gitignore vendored
View File

@ -250,5 +250,5 @@ scripts/stress-test/reports/
# Code Agent Folder
.qoder/*
.context/
.context/*
.eslintcache

View File

@ -75,19 +75,13 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@uv run --project api --dev ruff format ./api
@uv run --project api --dev ruff check --fix ./api
@$(MAKE) api-contract-lint
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
api-contract-lint:
@echo "🔎 Linting Flask response contracts..."
@uv run --project api --dev python api/dev/lint_response_contracts.py
@echo "✅ Response contract lint complete"
type-check:
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@ -197,7 +191,6 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@ -211,4 +204,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all

View File

@ -657,7 +657,6 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
@ -768,7 +767,6 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true

View File

@ -195,7 +195,6 @@ Before opening a PR / submitting:
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise docstrings and comments.
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,

View File

@ -22,11 +22,9 @@ RUN apt-get update \
libmpfr-dev libmpc-dev
# Install Python dependencies (workspace members under providers/vdb/)
COPY api/pyproject.toml api/uv.lock ./
COPY api/providers ./providers
COPY dify-agent/pyproject.toml dify-agent/README.md /app/dify-agent/
COPY dify-agent/src /app/dify-agent/src
# Trust the checked-in lock during image builds; local path sources are copied from the repository context.
COPY pyproject.toml uv.lock ./
COPY providers ./providers
# Trust the checked-in lock during image builds; dev-only path sources live outside the api/ context.
RUN uv sync --frozen --no-dev
# production stage
@ -110,10 +108,10 @@ RUN python -c "import tiktoken; tiktoken.encoding_for_model('gpt2')" \
&& chown -R dify:dify ${TIKTOKEN_CACHE_DIR}
# Copy source code
COPY --chown=dify:dify api /app/api/
COPY --chown=dify:dify . /app/api/
# Prepare entrypoint script
COPY --chown=dify:dify --chmod=755 api/docker/entrypoint.sh /entrypoint.sh
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh /entrypoint.sh
ARG COMMIT_SHA

View File

@ -1,25 +0,0 @@
*
!api/
!api/**
!dify-agent/
!dify-agent/pyproject.toml
!dify-agent/README.md
!dify-agent/src/
!dify-agent/src/**
api/.venv
api/.venv/**
api/.env
api/*.env.*
api/.idea
api/.mypy_cache
api/.ruff_cache
api/storage/generate_files/*
api/storage/privkeys/*
api/storage/tools/*
api/storage/upload_files/*
api/logs
api/*.log*
**/__pycache__
**/*.pyc

View File

@ -30,7 +30,7 @@ from clients.agent_backend.factory import create_agent_backend_run_client
from clients.agent_backend.fake_client import FakeAgentBackendRunClient, FakeAgentBackendScenario
from clients.agent_backend.request_builder import (
AGENT_SOUL_PROMPT_LAYER_ID,
DIFY_EXECUTION_CONTEXT_LAYER_ID,
DIFY_PLUGIN_CONTEXT_LAYER_ID,
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID,
WORKFLOW_USER_PROMPT_LAYER_ID,
AgentBackendModelConfig,
@ -42,7 +42,7 @@ from clients.agent_backend.request_builder import (
__all__ = [
"AGENT_SOUL_PROMPT_LAYER_ID",
"DIFY_EXECUTION_CONTEXT_LAYER_ID",
"DIFY_PLUGIN_CONTEXT_LAYER_ID",
"WORKFLOW_NODE_JOB_PROMPT_LAYER_ID",
"WORKFLOW_USER_PROMPT_LAYER_ID",
"AgentBackendError",

View File

@ -4,9 +4,7 @@ This module is intentionally an adapter, not a wire DTO package. The emitted
object is always ``dify_agent.protocol.CreateRunRequest`` so the Agent backend
protocol has a single owner. API-only context such as Agent Soul vs workflow job
prompt is preserved in layer names and metadata until the dedicated product
schemas land in later phases. Dify-owned execution identifiers are emitted as an
explicit ``dify.execution_context`` layer so the run request stays fully
composition-driven.
schemas land in later phases.
"""
from __future__ import annotations
@ -17,19 +15,18 @@ from agenton.compositor import CompositorSessionSnapshot
from agenton.layers import ExitIntent
from agenton_collections.layers.plain import PLAIN_PROMPT_LAYER_TYPE_ID, PromptLayerConfig
from dify_agent.layers.dify_plugin import (
DIFY_PLUGIN_LAYER_TYPE_ID,
DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
DifyPluginCredentialValue,
DifyPluginLayerConfig,
DifyPluginLLMLayerConfig,
)
from dify_agent.layers.execution_context import (
DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
DifyExecutionContextLayerConfig,
)
from dify_agent.layers.output import DIFY_OUTPUT_LAYER_TYPE_ID, DifyOutputLayerConfig
from dify_agent.protocol import (
DIFY_AGENT_MODEL_LAYER_ID,
DIFY_AGENT_OUTPUT_LAYER_ID,
CreateRunRequest,
ExecutionContext,
LayerExitSignals,
RunComposition,
RunLayerSpec,
@ -40,30 +37,27 @@ from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_validator
AGENT_SOUL_PROMPT_LAYER_ID = "agent_soul_prompt"
WORKFLOW_NODE_JOB_PROMPT_LAYER_ID = "workflow_node_job_prompt"
WORKFLOW_USER_PROMPT_LAYER_ID = "workflow_user_prompt"
DIFY_EXECUTION_CONTEXT_LAYER_ID = "execution_context"
DIFY_PLUGIN_CONTEXT_LAYER_ID = "plugin"
class AgentBackendModelConfig(BaseModel):
"""API-side model/plugin selection before it is converted to Dify Agent layers."""
tenant_id: str
plugin_id: str
model_provider: str
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class AgentBackendOutputConfig(BaseModel):
"""API-side structured output declaration for the conventional output layer.
The structured-output tool name is fixed to ``final_output`` inside
``dify_agent.layers.output`` so callers only control the JSON Schema plus
optional description/strictness metadata.
"""
"""API-side structured output declaration for the conventional output layer."""
json_schema: dict[str, JsonValue]
name: str = "final_result"
description: str | None = None
strict: bool | None = None
@ -74,7 +68,7 @@ class AgentBackendWorkflowNodeRunInput(BaseModel):
"""Inputs needed to build the first workflow-node-oriented Agent backend run request."""
model: AgentBackendModelConfig
execution_context: DifyExecutionContextLayerConfig
execution_context: ExecutionContext
workflow_node_job_prompt: str
user_prompt: str
agent_soul_prompt: str | None = None
@ -126,22 +120,24 @@ class AgentBackendRunRequestBuilder:
config=PromptLayerConfig(user=run_input.user_prompt),
),
RunLayerSpec(
name=DIFY_EXECUTION_CONTEXT_LAYER_ID,
type=DIFY_EXECUTION_CONTEXT_LAYER_TYPE_ID,
name=DIFY_PLUGIN_CONTEXT_LAYER_ID,
type=DIFY_PLUGIN_LAYER_TYPE_ID,
metadata=run_input.metadata,
config=run_input.execution_context,
config=DifyPluginLayerConfig(
tenant_id=run_input.model.tenant_id,
plugin_id=run_input.model.plugin_id,
user_id=run_input.model.user_id,
),
),
RunLayerSpec(
name=DIFY_AGENT_MODEL_LAYER_ID,
type=DIFY_PLUGIN_LLM_LAYER_TYPE_ID,
deps={"execution_context": DIFY_EXECUTION_CONTEXT_LAYER_ID},
deps={"plugin": DIFY_PLUGIN_CONTEXT_LAYER_ID},
metadata=run_input.metadata,
config=DifyPluginLLMLayerConfig(
plugin_id=run_input.model.plugin_id,
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
),
]
@ -155,6 +151,7 @@ class AgentBackendRunRequestBuilder:
metadata=run_input.metadata,
config=DifyOutputLayerConfig(
json_schema=run_input.output.json_schema,
name=run_input.output.name,
description=run_input.output.description,
strict=run_input.output.strict,
),
@ -163,6 +160,7 @@ class AgentBackendRunRequestBuilder:
return CreateRunRequest(
composition=RunComposition(layers=layers),
execution_context=run_input.execution_context,
purpose=run_input.purpose,
idempotency_key=run_input.idempotency_key,
metadata=run_input.metadata,

View File

@ -11,7 +11,6 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
@ -21,6 +20,7 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)

View File

@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Any, override
from typing import Any
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
@ -25,7 +25,6 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
@ -91,7 +90,6 @@ class DifyConfig(
# Thanks for your concentration and consideration.
@classmethod
@override
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],

View File

@ -1,4 +1,3 @@
from configs.extra.agent_backend_config import AgentBackendConfig
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
@ -6,7 +5,6 @@ from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
AgentBackendConfig,
ArchiveStorageConfig,
NotionConfig,
SentryConfig,

View File

@ -1,23 +0,0 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class AgentBackendConfig(BaseSettings):
"""
Configuration settings for the Agent backend runtime integration.
"""
AGENT_BACKEND_BASE_URL: str | None = Field(
description="Base URL for the Dify Agent backend service.",
default=None,
)
AGENT_BACKEND_USE_FAKE: bool = Field(
description="Use the deterministic in-process fake Agent backend client.",
default=False,
)
AGENT_BACKEND_FAKE_SCENARIO: str = Field(
description="Scenario used by the fake Agent backend client.",
default="success",
)

View File

@ -265,11 +265,6 @@ class PluginConfig(BaseSettings):
default=60 * 60,
)
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching tenant plugin model providers in Redis",
default=60 * 60 * 24,
)
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
description="Maximum allowed size (bytes) for plugin-generated files",
default=50 * 1024 * 1024,

View File

@ -2,7 +2,6 @@ from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic.types import NonNegativeInt
from pydantic_settings import BaseSettings
@ -71,24 +70,6 @@ class RedisPubSubConfig(BaseSettings):
default=600,
)
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
description=(
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
"an SSE client and the response stream actually closing.\n\n"
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
"return promptly while the daemon listener thread cleans itself up on the next poll "
"boundary - safe because the listener holds no critical state and exits within one poll "
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
),
default=2000,
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, override
from typing import Any
from pydantic import Field
from pydantic.fields import FieldInfo
@ -48,7 +48,6 @@ class ApolloSettingsSource(RemoteSettingsSource):
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

View File

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Mapping
from typing import Any, override
from typing import Any
from pydantic.fields import FieldInfo
@ -41,7 +41,6 @@ class NacosSettingsSource(RemoteSettingsSource):
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
field_value = self.remote_configs.get(field_name)
if field_value is None:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, override, runtime_checkable
from typing import Any, Protocol, final, runtime_checkable
from pydantic import BaseModel
@ -133,12 +133,10 @@ class NullAppContext(AppContext):
self._config = config or {}
self._extensions: dict[str, Any] = {}
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
@ -148,7 +146,6 @@ class NullAppContext(AppContext):
self._extensions[name] = extension
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield

View File

@ -6,7 +6,7 @@ import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final, override
from typing import Any, final
from flask import Flask, current_app, g
@ -30,18 +30,15 @@ class FlaskAppContext(AppContext):
"""
self._flask_app = flask_app
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():

View File

@ -36,24 +36,6 @@ class FileInfo(BaseModel):
size: int
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
decoded_url = urllib.parse.unquote(url)
if isinstance(query_string, bytes):
raw_query = query_string.decode()
else:
raw_query = query_string
if not raw_query:
return decoded_url
if decoded_url.endswith(("?", "&")):
separator = ""
elif urllib.parse.urlsplit(decoded_url).query:
separator = "&"
else:
separator = "?"
return f"{decoded_url}{separator}{raw_query}"
def guess_file_info_from_response(response: httpx.Response):
url = str(response.url)
# Try to extract filename from URL

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -82,7 +80,7 @@ class AgentRosterDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
def get(self, agent_id):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_roster_agent_detail(tenant_id=tenant_id, agent_id=str(agent_id))
@ -91,7 +89,7 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, agent_id: UUID):
def patch(self, agent_id):
account, tenant_id = current_account_with_tenant()
payload = RosterAgentUpdatePayload.model_validate(console_ns.payload or {})
return _agent_roster_service().update_roster_agent(
@ -102,7 +100,7 @@ class AgentRosterDetailApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, agent_id: UUID):
def delete(self, agent_id):
account, tenant_id = current_account_with_tenant()
_agent_roster_service().archive_roster_agent(tenant_id=tenant_id, agent_id=str(agent_id), account_id=account.id)
return "", 204
@ -113,7 +111,7 @@ class AgentRosterVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID):
def get(self, agent_id):
_, tenant_id = current_account_with_tenant()
return {"data": _agent_roster_service().list_agent_versions(tenant_id=tenant_id, agent_id=str(agent_id))}
@ -123,7 +121,7 @@ class AgentRosterVersionDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, agent_id: UUID, version_id: UUID):
def get(self, agent_id, version_id):
_, tenant_id = current_account_with_tenant()
return _agent_roster_service().get_agent_version_detail(
tenant_id=tenant_id,

View File

@ -1,5 +1,4 @@
from datetime import datetime
from uuid import UUID
import flask_restx
from flask_restx import Resource
@ -147,7 +146,7 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
@ -156,7 +155,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@ -165,7 +164,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
return super().post(resource_id)
@ -181,9 +180,9 @@ class AppApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for an app")
@console_ns.doc(params={"resource_id": "App ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
def delete(self, resource_id, api_key_id):
"""Delete an API key for an app"""
return super().delete(str(resource_id), str(api_key_id))
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.APP
resource_model = App
@ -196,7 +195,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
def get(self, resource_id: UUID):
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@ -205,7 +204,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id: UUID):
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""
return super().post(resource_id)
@ -221,9 +220,9 @@ class DatasetApiKeyResource(BaseApiKeyResource):
@console_ns.doc(description="Delete an API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID", "api_key_id": "API key ID"})
@console_ns.response(204, "API key deleted successfully")
def delete(self, resource_id: UUID, api_key_id: UUID):
def delete(self, resource_id, api_key_id):
"""Delete an API key for a dataset"""
return super().delete(str(resource_id), str(api_key_id))
return super().delete(resource_id, api_key_id)
resource_type = ApiTokenType.DATASET
resource_model = Dataset

View File

@ -159,15 +159,13 @@ class AppAnnotationSettingUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_id: UUID, annotation_setting_id: UUID):
annotation_setting_id_str = str(annotation_setting_id)
def post(self, app_id: UUID, annotation_setting_id):
annotation_setting_id = str(annotation_setting_id)
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(
str(app_id), annotation_setting_id_str, setting_args
)
result = AppAnnotationService.update_app_annotation_setting(str(app_id), annotation_setting_id, setting_args)
return result, 200
@ -183,9 +181,9 @@ class AnnotationReplyActionStatusApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
def get(self, app_id: UUID, job_id: UUID, action: str):
job_id_str = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
def get(self, app_id: UUID, job_id, action):
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
@ -193,10 +191,10 @@ class AnnotationReplyActionStatusApi(Resource):
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@console_ns.route("/apps/<uuid:app_id>/annotations")
@ -271,12 +269,12 @@ class AnnotationApi(Resource):
"message": "annotation_ids are required if the parameter is provided.",
}, 400
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return "", 204
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return result, 204
# If no annotation_ids are provided, handle clearing all annotations
else:
AppAnnotationService.clear_all_annotations(str(app_id))
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
@ -337,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")

View File

@ -633,7 +633,7 @@ class AppApi(Resource):
app_service = AppService()
app_service.delete_app(app_model)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/copy")

View File

@ -97,7 +97,7 @@ class AppImportConfirmApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, import_id: str):
def post(self, import_id):
# Check user role first
current_user, _ = current_account_with_tenant()

View File

@ -131,7 +131,7 @@ class CompletionMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
def post(self, app_model, task_id: str):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@ -212,7 +212,7 @@ class ChatMessageStopApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def post(self, app_model, task_id: str):
def post(self, app_model, task_id):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
import sqlalchemy as sa
from flask import abort, request
@ -30,6 +29,9 @@ from fields.conversation_fields import (
from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from fields.conversation_fields import (
ResultResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -75,6 +77,7 @@ register_schema_models(
ConversationMessageDetailResponse,
ConversationWithSummaryPaginationResponse,
ConversationDetailResponse,
ResultResponse,
CompletionConversationQuery,
ChatConversationQuery,
)
@ -134,7 +137,7 @@ class CompletionConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.group_by(Conversation.id)
.distinct()
)
elif args.annotation_status == "not_annotated":
query = (
@ -165,10 +168,10 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
conversation_id_str = str(conversation_id)
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return ConversationMessageDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_completion_conversation")
@ -182,16 +185,16 @@ class CompletionConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=AppMode.COMPLETION)
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
def delete(self, app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
conversation_id = str(conversation_id)
try:
ConversationService.delete(app_model, conversation_id_str, current_user)
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
@ -272,7 +275,7 @@ class ChatConversationApi(Resource):
.join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
.group_by(Conversation.id)
.distinct()
)
case "not_annotated":
query = (
@ -318,10 +321,10 @@ class ChatConversationDetailApi(Resource):
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@edit_permission_required
def get(self, app_model, conversation_id: UUID):
conversation_id_str = str(conversation_id)
def get(self, app_model, conversation_id):
conversation_id = str(conversation_id)
return ConversationDetailResponse.model_validate(
_get_conversation(app_model, conversation_id_str), from_attributes=True
_get_conversation(app_model, conversation_id), from_attributes=True
).model_dump(mode="json")
@console_ns.doc("delete_chat_conversation")
@ -335,16 +338,16 @@ class ChatConversationDetailApi(Resource):
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
@account_initialization_required
@edit_permission_required
def delete(self, app_model, conversation_id: UUID):
def delete(self, app_model, conversation_id):
current_user, _ = current_account_with_tenant()
conversation_id_str = str(conversation_id)
conversation_id = str(conversation_id)
try:
ConversationService.delete(app_model, conversation_id_str, current_user)
ConversationService.delete(app_model, conversation_id, current_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204
def _get_conversation(app_model, conversation_id):

View File

@ -1,7 +1,6 @@
import json
from datetime import datetime
from typing import Any
from uuid import UUID
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
@ -163,7 +162,7 @@ class AppMCPServerRefreshController(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, server_id: UUID):
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
server = db.session.scalar(
select(AppMCPServer)

View File

@ -1,7 +1,6 @@
import logging
from datetime import datetime
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -337,13 +336,13 @@ class MessageSuggestedQuestionApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
def get(self, app_model, message_id: UUID):
def get(self, app_model, message_id):
current_user, _ = current_account_with_tenant()
message_id_str = str(message_id)
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, message_id=message_id_str, user=current_user, invoke_from=InvokeFrom.DEBUGGER
app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER
)
except MessageNotExistsError:
raise NotFound("Message not found")
@ -417,11 +416,11 @@ class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model, message_id: UUID):
message_id_str = str(message_id)
def get(self, app_model, message_id: str):
message_id = str(message_id)
message = db.session.scalar(
select(Message).where(Message.id == message_id_str, Message.app_id == app_model.id).limit(1)
select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1)
)
if not message:

View File

@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return "", 204
return {"result": "success"}, 204
except Exception as e:
raise BadRequest(str(e))

View File

@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
user_id=current_user.id,
)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
user_id=current_user.id,
)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -2,7 +2,6 @@ import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -346,15 +345,14 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: UUID):
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
return variable
@ -365,7 +363,7 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: UUID):
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
# Local File:
@ -392,11 +390,10 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
new_name = args_model.name
@ -437,15 +434,14 @@ class VariableApi(Resource):
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: UUID):
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
@ -461,7 +457,7 @@ class VariableResetApi(Resource):
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: UUID):
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -472,11 +468,10 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
variable=draft_var_srv.get_variable(variable_id=variable_id),
app_id=app_model.id,
variable_id=variable_id_str,
variable_id=variable_id,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)

View File

@ -1,6 +1,5 @@
from datetime import UTC, datetime, timedelta
from typing import Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -189,7 +188,7 @@ class WorkflowRunExportApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id: str):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)
@ -368,14 +367,14 @@ class WorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id):
"""
Get workflow run detail
"""
run_id_str = str(run_id)
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id_str)
workflow_run = workflow_run_service.get_workflow_run(app_model=app_model, run_id=run_id)
if workflow_run is None:
raise NotFoundError("Workflow run not found")
@ -397,17 +396,17 @@ class WorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, run_id: UUID):
def get(self, app_model: App, run_id):
"""
Get workflow run node execution list
"""
run_id_str = str(run_id)
run_id = str(run_id)
workflow_run_service = WorkflowRunService()
user = cast("Account | EndUser", current_user)
node_executions = workflow_run_service.get_workflow_run_node_executions(
app_model=app_model,
run_id=run_id_str,
run_id=run_id,
user=user,
)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -89,10 +87,10 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@account_initialization_required
@is_admin_or_owner_required
@console_ns.response(204, "Binding deleted successfully")
def delete(self, binding_id: UUID):
def delete(self, binding_id):
# The role of the current user in the table must be admin or owner
_, current_tenant_id = current_account_with_tenant()
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return "", 204
return {"result": "success"}, 204

View File

@ -1,5 +1,4 @@
import logging
from uuid import UUID
import httpx
from flask import current_app, redirect, request
@ -159,15 +158,16 @@ class OAuthDataSourceSync(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider: str, binding_id: UUID):
binding_id_str = str(binding_id)
def get(self, provider, binding_id):
provider = str(provider)
binding_id = str(binding_id)
OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
try:
oauth_provider.sync_data_source(binding_id_str)
oauth_provider.sync_data_source(binding_id)
except httpx.HTTPStatusError as e:
logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text

View File

@ -1,7 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, Literal, cast
from uuid import UUID
from flask import request
from flask_restx import Resource, fields, marshal_with
@ -294,7 +293,7 @@ class DataSourceNotionApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, page_id: UUID, page_type: str):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
@ -307,11 +306,11 @@ class DataSourceNotionApi(Resource):
plugin_id="langgenius/notion_datasource",
)
page_id_str = str(page_id)
page_id = str(page_id)
extractor = NotionExtractor(
notion_workspace_id="",
notion_obj_id=page_id_str,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=current_tenant_id,
@ -368,7 +367,7 @@ class DataSourceNotionDatasetSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -386,7 +385,7 @@ class DataSourceNotionDocumentSyncApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -1,17 +1,15 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from typing import Any, cast
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
@ -32,10 +30,26 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import build_icon_url, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
@ -47,6 +61,58 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
@ -142,165 +208,9 @@ class ConsoleDatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetListItemResponse(DatasetDetailResponse):
partial_member_list: list[str]
class DatasetListResponse(ResponseModel):
data: list[DatasetListItemResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
class DatasetQueryFileInfoResponse(ResponseModel):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetQueryContentResponse(ResponseModel):
content_type: str
content: str
file_info: DatasetQueryFileInfoResponse | None = None
class DatasetQueryDetailResponse(ResponseModel):
id: str
queries: list[DatasetQueryContentResponse]
source: str
source_app_id: str | None
created_by_role: str
created_by: str
created_at: int
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DatasetQueryListResponse(ResponseModel):
data: list[DatasetQueryDetailResponse]
has_more: bool
limit: int
total: int
page: int
class RelatedAppResponse(ResponseModel):
id: str
name: str
description: str
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None
icon: str | None
icon_background: str | None
icon_url: str | None = None
@model_validator(mode="after")
def _set_icon_url(self) -> "RelatedAppResponse":
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
return self
class RelatedAppListResponse(ResponseModel):
data: list[RelatedAppResponse]
total: int
class DocumentStatusResponse(ResponseModel):
id: str
indexing_status: str
processing_started_at: int | None
parsing_completed_at: int | None
cleaning_completed_at: int | None
splitting_completed_at: int | None
completed_at: int | None
paused_at: int | None
error: str | None
stopped_at: int | None
completed_segments: int | None = None
total_segments: int | None = None
@field_validator(
"processing_started_at",
"parsing_completed_at",
"cleaning_completed_at",
"splitting_completed_at",
"completed_at",
"paused_at",
"stopped_at",
mode="before",
)
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentStatusListResponse(ResponseModel):
data: list[DocumentStatusResponse]
class ErrorDocsResponse(DocumentStatusListResponse):
total: int
class IndexingEstimatePreviewItemResponse(ResponseModel):
content: str
child_chunks: list[str] | None = None
summary: str | None = None
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
question: str
answer: str
class IndexingEstimateResponse(ResponseModel):
total_segments: int
preview: list[IndexingEstimatePreviewItemResponse]
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
class RetrievalSettingResponse(ResponseModel):
retrieval_method: list[str]
class PartialMemberListResponse(ResponseModel):
data: list[str]
class AutoDisableLogsResponse(ResponseModel):
document_ids: list[str]
count: int
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
register_response_schema_models(
console_ns,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetQueryListResponse,
IndexingEstimateResponse,
RelatedAppListResponse,
DocumentStatusListResponse,
ErrorDocsResponse,
RetrievalSettingResponse,
PartialMemberListResponse,
AutoDisableLogsResponse,
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -383,8 +293,17 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
class DatasetListApi(Resource):
@console_ns.doc("get_datasets")
@console_ns.doc(description="Get list of datasets")
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
@console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
"ids": "Filter by dataset IDs (list)",
"keyword": "Search keyword",
"tag_ids": "Filter by tag IDs (list)",
"include_all": "Include all datasets (default: false)",
}
)
@console_ns.response(200, "Datasets retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@ -423,7 +342,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
partial_members_map: dict[str, list[str]] = {}
if dataset_ids:
@ -460,12 +379,12 @@ class DatasetListApi(Resource):
"total": total,
"page": query.page,
}
return dump_response(DatasetListResponse, response), 200
return response, 200
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@ -494,7 +413,7 @@ class DatasetListApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return dump_response(DatasetDetailResponse, dataset), 201
return marshal(dataset, dataset_detail_fields), 201
@console_ns.route("/datasets/<uuid:dataset_id>")
@ -502,17 +421,13 @@ class DatasetApi(Resource):
@console_ns.doc("get_dataset")
@console_ns.doc(description="Get dataset details")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Dataset retrieved successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -522,7 +437,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = dump_response(DatasetDetailResponse, dataset)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -555,18 +470,14 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(
200,
"Dataset updated successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID):
def patch(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -595,7 +506,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = dump_response(DatasetDetailResponse, dataset)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_tenant_id
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
@ -614,7 +525,7 @@ class DatasetApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Dataset deleted successfully")
def delete(self, dataset_id: UUID):
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
@ -624,7 +535,7 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return "", 204
return {"result": "success"}, 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@ -644,7 +555,7 @@ class DatasetUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
@ -656,15 +567,11 @@ class DatasetQueryApi(Resource):
@console_ns.doc("get_dataset_queries")
@console_ns.doc(description="Get dataset query history")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Query history retrieved successfully",
console_ns.models[DatasetQueryListResponse.__name__],
)
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -682,24 +589,20 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
"data": dataset_queries,
"data": marshal(dataset_queries, dataset_query_detail_model),
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
}
return dump_response(DatasetQueryListResponse, response), 200
return response, 200
@console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource):
@console_ns.doc("estimate_dataset_indexing")
@console_ns.doc(description="Estimate dataset indexing cost")
@console_ns.response(
200,
"Indexing estimate calculated successfully",
console_ns.models[IndexingEstimateResponse.__name__],
)
@console_ns.response(200, "Indexing estimate calculated successfully")
@setup_required
@login_required
@account_initialization_required
@ -796,15 +699,12 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.doc("get_dataset_related_apps")
@console_ns.doc(description="Get applications related to dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Related apps retrieved successfully",
console_ns.models[RelatedAppListResponse.__name__],
)
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
@marshal_with(related_app_list_model)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -824,7 +724,7 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
return {"data": related_apps, "total": len(related_apps)}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
@ -832,19 +732,15 @@ class DatasetIndexingStatusApi(Resource):
@console_ns.doc("get_dataset_indexing_status")
@console_ns.doc(description="Get dataset indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Indexing status retrieved successfully",
console_ns.models[DocumentStatusListResponse.__name__],
)
@console_ns.response(200, "Indexing status retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@ -882,8 +778,9 @@ class DatasetIndexingStatusApi(Resource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data, 200
@console_ns.route("/datasets/api-keys")
@ -952,15 +849,15 @@ class DatasetApiDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, api_key_id: UUID):
def delete(self, api_key_id):
_, current_tenant_id = current_account_with_tenant()
api_key_id_str = str(api_key_id)
api_key_id = str(api_key_id)
key = db.session.scalar(
select(ApiToken)
.where(
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id_str,
ApiToken.id == api_key_id,
)
.limit(1)
)
@ -976,7 +873,7 @@ class DatasetApiDeleteApi(Resource):
db.session.delete(key)
db.session.commit()
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
@ -985,7 +882,7 @@ class DatasetEnableApiApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, dataset_id: UUID, status: str):
def post(self, dataset_id, status):
dataset_id_str = str(dataset_id)
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
@ -1010,18 +907,13 @@ class DatasetApiBaseUrlApi(Resource):
class DatasetRetrievalSettingApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting")
@console_ns.doc(description="Get dataset retrieval settings")
@console_ns.response(
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
)
@console_ns.response(200, "Retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
)
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -1029,19 +921,12 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting_mock")
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
@console_ns.doc(params={"vector_type": "Vector store type"})
@console_ns.response(
200,
"Mock retrieval settings retrieved successfully",
console_ns.models[RetrievalSettingResponse.__name__],
)
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type: str):
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
)
def get(self, vector_type):
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@ -1049,19 +934,19 @@ class DatasetErrorDocs(Resource):
@console_ns.doc("get_dataset_error_docs")
@console_ns.doc(description="Get dataset error documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
@console_ns.response(200, "Error documents retrieved successfully")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
@ -1069,17 +954,13 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.doc("get_dataset_permission_users")
@console_ns.doc(description="Get dataset permission user list")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Permission users retrieved successfully",
console_ns.models[PartialMemberListResponse.__name__],
)
@console_ns.response(200, "Permission users retrieved successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -1092,7 +973,9 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
return {
"data": partial_members_list,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
@ -1100,18 +983,14 @@ class DatasetAutoDisableLogApi(Resource):
@console_ns.doc("get_dataset_auto_disable_logs")
@console_ns.doc(description="Get dataset auto disable logs")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(
200,
"Auto disable logs retrieved successfully",
console_ns.models[AutoDisableLogsResponse.__name__],
)
@console_ns.response(200, "Auto disable logs retrieved successfully")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200

View File

@ -5,7 +5,6 @@ from collections.abc import Sequence
from contextlib import ExitStack
from datetime import datetime
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
@ -316,9 +315,9 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
@ -343,7 +342,7 @@ class DatasetDocumentListApi(Resource):
)
except (ArgumentTypeError, ValueError, Exception):
fetch = False
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -352,7 +351,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == current_tenant_id)
query = select(Document).where(Document.dataset_id == str(dataset_id), Document.tenant_id == current_tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
@ -373,7 +372,7 @@ class DatasetDocumentListApi(Resource):
sa.select(
DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")
)
.where(DocumentSegment.dataset_id == dataset_id_str)
.where(DocumentSegment.dataset_id == str(dataset_id))
.group_by(DocumentSegment.document_id)
.subquery()
)
@ -445,11 +444,11 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -473,7 +472,7 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -491,9 +490,9 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Documents deleted successfully")
def delete(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
@ -505,7 +504,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/init")
@ -583,11 +582,11 @@ class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
@ -625,7 +624,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
data_process_rule_dict,
document.doc_form,
"English",
dataset_id_str,
dataset_id,
)
return estimate_response.model_dump(), 200
except LLMBadRequestError:
@ -648,10 +647,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
def get(self, dataset_id, batch):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
@ -725,7 +725,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
data_process_rule_dict,
document.doc_form,
"English",
dataset_id_str,
dataset_id,
)
return response.model_dump(), 200
except LLMBadRequestError:
@ -745,9 +745,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
def get(self, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch)
documents_status = []
for document in documents:
completed_segments = (
@ -799,16 +800,16 @@ class DocumentIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -817,7 +818,7 @@ class DocumentIndexingStatusApi(DocumentResource):
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -860,10 +861,10 @@ class DocumentApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
@ -872,7 +873,7 @@ class DocumentApi(DocumentResource):
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
response = {
"id": document.id,
@ -906,7 +907,7 @@ class DocumentApi(DocumentResource):
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
response = {
"id": document.id,
@ -949,23 +950,23 @@ class DocumentApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id, document_id)
try:
DocumentService.delete_document(document)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
@ -979,7 +980,7 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@ -996,16 +997,16 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id: str):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
dataset_id=dataset_id_str,
dataset_id=dataset_id,
document_ids=document_ids,
tenant_id=current_tenant_id,
current_user=current_user,
@ -1043,11 +1044,11 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
def patch(self, dataset_id, document_id, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1091,11 +1092,11 @@ class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id: UUID, document_id: UUID):
def put(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
document = self.get_document(dataset_id, document_id)
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
@ -1140,10 +1141,10 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
def patch(self, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@ -1178,16 +1179,16 @@ class DocumentPauseApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document paused successfully")
def patch(self, dataset_id: UUID, document_id: UUID):
def patch(self, dataset_id, document_id):
"""pause document."""
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
@ -1203,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot pause completed document.")
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
@ -1213,14 +1214,14 @@ class DocumentRecoverApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document resumed successfully")
def patch(self, dataset_id: UUID, document_id: UUID):
def patch(self, dataset_id, document_id):
"""recover document."""
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:
@ -1235,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Document is not in paused status.")
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
@ -1246,11 +1247,11 @@ class DocumentRetryApi(DocumentResource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentRetryPayload.__name__])
@console_ns.response(204, "Documents retry started successfully")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
"""retry document."""
payload = DocumentRetryPayload.model_validate(console_ns.payload or {})
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
retry_documents = []
if not dataset:
raise NotFound("Dataset not found.")
@ -1276,9 +1277,9 @@ class DocumentRetryApi(DocumentResource):
logger.exception("Failed to retry document, document id: %s", document_id)
continue
# retry document
DocumentService.retry_document(dataset_id_str, retry_documents)
DocumentService.retry_document(dataset_id, retry_documents)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
@ -1288,7 +1289,7 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
@ -1300,7 +1301,7 @@ class DocumentRenameApi(DocumentResource):
payload = DocumentRenamePayload.model_validate(console_ns.payload or {})
try:
document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name)
document = DocumentService.rename_document(dataset_id, document_id, payload.name)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
@ -1313,15 +1314,15 @@ class WebsiteDocumentSyncApi(DocumentResource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != current_tenant_id:
@ -1332,7 +1333,7 @@ class WebsiteDocumentSyncApi(DocumentResource):
if DocumentService.check_archived(document):
raise ArchivedDocumentImmutableError()
# sync document
DocumentService.sync_website_document(dataset_id_str, document)
DocumentService.sync_website_document(dataset_id, document)
return {"result": "success"}, 200
@ -1342,19 +1343,19 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = db.session.scalar(
select(DocumentPipelineExecutionLog)
.where(DocumentPipelineExecutionLog.document_id == document_id_str)
.where(DocumentPipelineExecutionLog.document_id == document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.limit(1)
)
@ -1391,7 +1392,7 @@ class DocumentGenerateSummaryApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
"""
Generate summary index for specified documents.
@ -1400,10 +1401,10 @@ class DocumentGenerateSummaryApi(Resource):
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -1437,7 +1438,7 @@ class DocumentGenerateSummaryApi(Resource):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list)
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
@ -1451,7 +1452,7 @@ class DocumentGenerateSummaryApi(Resource):
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id_str,
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
@ -1464,11 +1465,11 @@ class DocumentGenerateSummaryApi(Resource):
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id_str, document.id)
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id_str,
dataset_id,
)
return {"result": "success"}, 200
@ -1484,7 +1485,7 @@ class DocumentSummaryStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
@ -1498,11 +1499,11 @@ class DocumentSummaryStatusApi(DocumentResource):
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -1516,8 +1517,8 @@ class DocumentSummaryStatusApi(DocumentResource):
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id_str,
dataset_id=dataset_id_str,
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200

View File

@ -1,6 +1,4 @@
import uuid
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource, marshal
@ -115,12 +113,12 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
def get(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -129,7 +127,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
@ -150,7 +148,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = (
select(DocumentSegment)
.where(
DocumentSegment.document_id == document_id_str,
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
@ -203,9 +201,7 @@ class DatasetDocumentSegmentListApi(Resource):
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(
segment_ids=segment_ids, dataset_id=dataset_id_str
)
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
@ -230,19 +226,19 @@ class DatasetDocumentSegmentListApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
def delete(self, dataset_id, document_id):
current_user, _ = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
@ -255,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
@ -266,15 +262,15 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
def patch(self, dataset_id, document_id, action):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting
@ -325,17 +321,17 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if not current_user.is_dataset_editor:
@ -365,7 +361,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -376,19 +372,19 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -408,10 +404,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -432,33 +428,33 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def delete(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -471,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return "", 204
return {"result": "success"}, 204
@console_ns.route(
@ -487,17 +483,17 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
def post(self, dataset_id, document_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
@ -521,8 +517,8 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
batch_create_segment_to_index_task.delay(
str(job_id),
upload_file_id,
dataset_id_str,
document_id_str,
dataset_id,
document_id,
current_tenant_id,
current_user.id,
)
@ -534,7 +530,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, job_id=None, dataset_id: UUID | None = None, document_id: UUID | None = None):
def get(self, job_id=None, dataset_id=None, document_id=None):
if job_id is None:
raise NotFound("The job does not exist.")
job_id = str(job_id)
@ -555,24 +551,24 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def post(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -610,26 +606,26 @@ class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def get(self, dataset_id, document_id, segment_id):
_, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -646,9 +642,7 @@ class ChildChunkAddApi(Resource):
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
@ -662,26 +656,26 @@ class ChildChunkAddApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def patch(self, dataset_id, document_id, segment_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
@ -711,39 +705,39 @@ class ChildChunkUpdateApi(Resource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
ChildChunk.document_id == document_id,
)
.limit(1)
)
@ -760,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return "", 204
return {"result": "success"}, 204
@setup_required
@login_required
@ -768,39 +762,39 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
current_user, current_tenant_id = current_account_with_tenant()
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id_str = str(document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id_str = str(segment_id)
segment_id = str(segment_id)
segment = db.session.scalar(
select(DocumentSegment)
.where(DocumentSegment.id == segment_id_str, DocumentSegment.tenant_id == current_tenant_id)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
.limit(1)
)
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id_str = str(child_chunk_id)
child_chunk_id = str(child_chunk_id)
child_chunk = db.session.scalar(
select(ChildChunk)
.where(
ChildChunk.id == str(child_chunk_id_str),
ChildChunk.id == str(child_chunk_id),
ChildChunk.tenant_id == current_tenant_id,
ChildChunk.segment_id == segment.id,
ChildChunk.document_id == document_id_str,
ChildChunk.document_id == document_id,
)
.limit(1)
)

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from flask_restx import Resource, fields, marshal
from pydantic import BaseModel, Field
@ -10,12 +8,7 @@ from controllers.common.fields import UsageCountResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
)
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.dataset_fields import (
dataset_detail_fields,
dataset_retrieval_model_fields,
@ -131,9 +124,9 @@ class ExternalApiTemplateListApi(Resource):
@console_ns.response(200, "External API templates retrieved successfully")
@setup_required
@login_required
@with_current_tenant_id
@account_initialization_required
def get(self, current_tenant_id: str):
def get(self):
_, current_tenant_id = current_account_with_tenant()
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
@ -182,11 +175,11 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id_str, current_tenant_id
external_knowledge_api_id, current_tenant_id
)
if external_knowledge_api is None:
raise NotFound("API template not found.")
@ -197,9 +190,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id: UUID):
def patch(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
@ -207,7 +200,7 @@ class ExternalApiTemplateApi(Resource):
external_knowledge_api = ExternalDatasetService.update_external_knowledge_api(
tenant_id=current_tenant_id,
user_id=current_user.id,
external_knowledge_api_id=external_knowledge_api_id_str,
external_knowledge_api_id=external_knowledge_api_id,
args=payload.model_dump(),
)
@ -217,15 +210,15 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(204, "External knowledge API deleted successfully")
def delete(self, external_knowledge_api_id: UUID):
def delete(self, external_knowledge_api_id):
current_user, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id_str)
return "", 204
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
@ -237,12 +230,12 @@ class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
def get(self, external_knowledge_api_id):
_, current_tenant_id = current_account_with_tenant()
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_id = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
external_knowledge_api_id_str, current_tenant_id
external_knowledge_api_id, current_tenant_id
)
return {"is_using": external_knowledge_api_is_using, "count": count}, 200
@ -293,7 +286,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from flask_restx import Resource
from pydantic import Field, field_validator
@ -119,7 +118,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
def post(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)

View File

@ -1,19 +1,14 @@
from typing import Literal
from uuid import UUID
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from fields.dataset_fields import dataset_metadata_fields
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
@ -27,12 +22,7 @@ from services.metadata_service import MetadataService
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
register_response_schema_models(
console_ns,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -41,9 +31,9 @@ class DatasetMetadataCreateApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id: UUID):
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
metadata_args = MetadataArgs.model_validate(console_ns.payload or {})
@ -54,22 +44,18 @@ class DatasetMetadataCreateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return dump_response(DatasetMetadataResponse, metadata), 201
return metadata, 201
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, dataset_id: UUID):
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
return MetadataService.get_dataset_metadatas(dataset), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -78,9 +64,9 @@ class DatasetMetadataApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@marshal_with(dataset_metadata_fields)
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id: UUID, metadata_id: UUID):
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
payload = MetadataUpdatePayload.model_validate(console_ns.payload or {})
name = payload.name
@ -93,14 +79,14 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return dump_response(DatasetMetadataResponse, metadata), 200
return metadata, 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Metadata deleted successfully")
def delete(self, dataset_id: UUID, metadata_id: UUID):
def delete(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -110,8 +96,7 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
return {"result": "success"}, 204
@console_ns.route("/datasets/metadata/built-in")
@ -120,14 +105,9 @@ class DatasetMetadataBuiltInFieldApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200,
"Built-in fields retrieved successfully",
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self):
built_in_fields = MetadataService.get_built_in_fields()
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
return {"fields": built_in_fields}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -136,8 +116,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id: UUID, action: Literal["enable", "disable"]):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -150,8 +130,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -161,11 +140,8 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
@console_ns.response(
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id: UUID):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -177,5 +153,4 @@ class DocumentMetadataEditApi(Resource):
MetadataService.update_documents_metadata(dataset, metadata_args)
# Frontend callers only await success and invalidate caches; no response body is consumed.
return "", 204
return {"result": "success"}, 200

View File

@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
import services
@ -54,13 +54,12 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
session.commit()
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_tenant_id,

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
@ -169,22 +168,21 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: UUID):
def get(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: UUID):
def patch(self, pipeline: Pipeline, variable_id: str):
# Request payload for file types:
#
# Local File:
@ -212,12 +210,11 @@ class RagPipelineVariableApi(Resource):
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
@ -253,16 +250,15 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: UUID):
def delete(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -271,7 +267,7 @@ class RagPipelineVariableApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: UUID):
def put(self, pipeline: Pipeline, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -282,12 +278,11 @@ class RagPipelineVariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()

View File

@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@ -67,12 +67,10 @@ class RagPipelineImportApi(Resource):
current_user, _ = current_account_with_tenant()
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Use a plain Session so that caught exceptions inside the service
# (which return FAILED status instead of re-raising) do not leave the
# transaction in a closed state that a .begin() context manager cannot
# handle. See app_import.py for the canonical pattern.
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
result = import_service.import_rag_pipeline(
account=account,
@ -82,10 +80,6 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
status = result.status
@ -105,17 +99,15 @@ class RagPipelineImportConfirmApi(Resource):
@account_initialization_required
@edit_permission_required
@marshal_with(pipeline_import_model)
def post(self, import_id: str):
def post(self, import_id):
current_user, _ = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
# Create service with session
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
if result.status == ImportStatus.FAILED:
session.rollback()
else:
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@ -132,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@ -150,7 +142,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with Session(db.engine, expire_on_commit=False) as session:
with sessionmaker(db.engine).begin() as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@ -1,7 +1,6 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource
@ -876,14 +875,14 @@ class RagPipelineWorkflowRunDetailApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: UUID):
def get(self, pipeline: Pipeline, run_id):
"""
Get workflow run detail
"""
run_id_str = str(run_id)
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id_str)
workflow_run = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline=pipeline, run_id=run_id)
if workflow_run is None:
raise NotFound("Workflow run not found")
@ -901,17 +900,17 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: UUID):
def get(self, pipeline: Pipeline, run_id: str):
"""
Get workflow run node execution list
"""
run_id_str = str(run_id)
run_id = str(run_id)
rag_pipeline_service = RagPipelineService()
user = cast("Account | EndUser", current_user)
node_executions = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions(
pipeline=pipeline,
run_id=run_id_str,
run_id=run_id,
user=user,
)
@ -961,15 +960,15 @@ class RagPipelineTransformApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
def post(self, dataset_id: str):
current_user, _ = current_account_with_tenant()
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
dataset_id_str = str(dataset_id)
dataset_id = str(dataset_id)
rag_pipeline_transform_service = RagPipelineTransformService()
result = rag_pipeline_transform_service.transform_dataset(dataset_id_str)
result = rag_pipeline_transform_service.transform_dataset(dataset_id)
return result

View File

@ -133,7 +133,7 @@ class CompletionApi(InstalledAppResource):
)
class CompletionStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app, task_id: str):
def post(self, installed_app, task_id):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
@ -209,7 +209,7 @@ class ChatApi(InstalledAppResource):
)
class ChatStopApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, installed_app, task_id: str):
def post(self, installed_app, task_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
@ -92,7 +91,7 @@ class ConversationListApi(InstalledAppResource):
)
class ConversationApi(InstalledAppResource):
@console_ns.response(204, "Conversation deleted successfully")
def delete(self, installed_app, c_id: UUID):
def delete(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -106,7 +105,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route(
@ -115,7 +114,7 @@ class ConversationApi(InstalledAppResource):
)
class ConversationRenameApi(InstalledAppResource):
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id: UUID):
def post(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -146,7 +145,7 @@ class ConversationRenameApi(InstalledAppResource):
)
class ConversationPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app, c_id: UUID):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -170,7 +169,7 @@ class ConversationPinApi(InstalledAppResource):
)
class ConversationUnPinApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def patch(self, installed_app, c_id: UUID):
def patch(self, installed_app, c_id):
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:

View File

@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
db.session.delete(installed_app)
db.session.commit()
return "", 204
return {"result": "success", "message": "App uninstalled successfully"}, 204
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
def patch(self, installed_app):

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from pydantic import BaseModel, TypeAdapter
@ -96,18 +95,18 @@ class MessageListApi(InstalledAppResource):
class MessageFeedbackApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__])
def post(self, installed_app, message_id: UUID):
def post(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id_str = str(message_id)
message_id = str(message_id)
payload = MessageFeedbackPayload.model_validate(console_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id_str,
message_id=message_id,
user=current_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
@ -124,13 +123,13 @@ class MessageFeedbackApi(InstalledAppResource):
)
class MessageMoreLikeThisApi(InstalledAppResource):
@console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__])
def get(self, installed_app, message_id: UUID):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
message_id_str = str(message_id)
message_id = str(message_id)
args = MoreLikeThisQuery.model_validate(request.args.to_dict())
@ -140,7 +139,7 @@ class MessageMoreLikeThisApi(InstalledAppResource):
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=current_user,
message_id=message_id_str,
message_id=message_id,
invoke_from=InvokeFrom.EXPLORE,
streaming=streaming,
)
@ -170,18 +169,18 @@ class MessageMoreLikeThisApi(InstalledAppResource):
)
class MessageSuggestedQuestionApi(InstalledAppResource):
@console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__])
def get(self, installed_app, message_id: UUID):
def get(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id_str = str(message_id)
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id_str, invoke_from=InvokeFrom.EXPLORE
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
@ -9,10 +7,9 @@ from controllers.common.schema import register_response_schema_models, register_
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from models import Account
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -23,8 +20,8 @@ register_response_schema_models(console_ns, ResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
@with_current_user
def get(self, current_user: Account, installed_app):
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -47,8 +44,8 @@ class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
@with_current_user
def post(self, current_user: Account, installed_app):
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -68,15 +65,15 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
@console_ns.response(204, "Saved message deleted successfully")
@with_current_user
def delete(self, current_user: Account, installed_app, message_id: UUID):
def delete(self, installed_app, message_id):
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
message_id_str = str(message_id)
message_id = str(message_id)
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, current_user, message_id_str)
SavedMessageService.delete(app_model, current_user, message_id)
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@ -13,7 +13,6 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -26,7 +25,7 @@ from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from models import Account
from libs.login import current_account_with_tenant
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -42,11 +41,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
def post(self, installed_app: InstalledApp):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()

View File

@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -153,7 +152,7 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, id: UUID):
def get(self, id):
api_based_extension_id = str(id)
_, tenant_id = current_account_with_tenant()
@ -169,7 +168,7 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, id: UUID):
def post(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@ -197,7 +196,7 @@ class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, id: UUID):
def delete(self, id):
api_based_extension_id = str(id)
_, current_tenant_id = current_account_with_tenant()
@ -205,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return "", 204
return {"result": "success"}, 204

View File

@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
@console_ns.route("/features")
@ -28,32 +28,7 @@ class FeatureApi(Resource):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
).model_dump()
payload.pop("vector_space", None)
return payload
@console_ns.route("/features/vector-space")
class FeatureVectorSpaceApi(Resource):
@console_ns.doc("get_tenant_feature_vector_space")
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
@console_ns.response(
200,
"Success",
console_ns.models[LimitationModel.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
"""Get vector-space usage and limit for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_vector_space(current_tenant_id).model_dump()
return FeatureService.get_features(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -107,10 +106,10 @@ class FilePreviewApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, file_id: UUID):
file_id_str = str(file_id)
def get(self, file_id):
file_id = str(file_id)
_, tenant_id = current_account_with_tenant()
text = FileService(db.engine).get_file_preview(file_id_str, tenant_id)
text = FileService(db.engine).get_file_preview(file_id, tenant_id)
return {"content": text}

View File

@ -8,14 +8,8 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_user,
)
from libs.login import login_required
from models import Account
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
@ -76,10 +70,11 @@ class NotificationApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
def get(self, current_user: Account):
def get(self):
current_user, _ = current_account_with_tenant()
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
@ -118,11 +113,11 @@ class NotificationDismissApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self, current_user: Account):
def post(self):
current_user, _ = current_account_with_tenant()
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,

View File

@ -1,5 +1,6 @@
import urllib.parse
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -33,7 +34,7 @@ class GetRemoteFileInfo(Resource):
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
@login_required
def get(self, url: str):
decoded_url = helpers.decode_remote_url(url, request.query_string)
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -132,17 +131,17 @@ class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id: UUID):
def patch(self, tag_id):
current_user, _ = current_account_with_tenant()
tag_id_str = str(tag_id)
tag_id = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagUpdateRequestPayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id_str)
tag = TagService.update_tags(UpdateTagPayload(name=payload.name), tag_id)
binding_count = TagService.get_tag_binding_count(tag_id_str)
binding_count = TagService.get_tag_binding_count(tag_id)
response = TagResponse.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
@ -155,10 +154,10 @@ class TagUpdateDeleteApi(Resource):
@account_initialization_required
@edit_permission_required
@console_ns.response(204, "Tag deleted successfully")
def delete(self, tag_id: UUID):
tag_id_str = str(tag_id)
def delete(self, tag_id):
tag_id = str(tag_id)
TagService.delete_tag(tag_id_str)
TagService.delete_tag(tag_id)
return "", 204

View File

@ -56,12 +56,6 @@ from models.enums import CreatorUserRole
from models.model import UploadFile
from services.account_service import AccountService
from services.billing_service import BillingService
from services.entities.auth_entities import (
ChangeEmailNewEmailToken,
ChangeEmailNewEmailVerifiedToken,
ChangeEmailOldEmailToken,
ChangeEmailOldEmailVerifiedToken,
)
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@ -626,8 +620,8 @@ class ChangeEmailSendEmailApi(Resource):
language = "zh-Hans"
else:
language = "en-US"
account = current_user
user_email = current_user.email
account = None
user_email = None
email_for_sending = args.email.lower()
# Default to the initial phase; any legacy/unexpected client input is
# coerced back to `old_email` so we never trust the caller to declare
@ -642,18 +636,24 @@ class ChangeEmailSendEmailApi(Resource):
if reset_data is None:
raise InvalidTokenError()
if not isinstance(reset_data, ChangeEmailOldEmailVerifiedToken):
# The token used to request a new-email code must come from the
# old-email verification step. This prevents the bypass described
# in GHSA-4q3w-q5mc-45rq where the phase-1 token was reused here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
if not reset_data.is_bound_to_account(current_user.id):
raise InvalidTokenError()
user_email = reset_data.email
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
user_email = current_user.email
else:
if email_for_sending != current_user.email.lower():
raise InvalidEmailError()
email_for_sending = current_user.email
account = AccountService.get_account_by_email_with_case_fallback(args.email)
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account,
@ -674,7 +674,6 @@ class ChangeEmailCheckApi(Resource):
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
@ -687,26 +686,42 @@ class ChangeEmailCheckApi(Resource):
token_data = AccountService.get_change_email_data(args.token)
if token_data is None:
raise InvalidTokenError()
if not token_data.is_bound_to_account(current_user.id):
raise InvalidTokenError()
normalized_token_email = token_data.email.lower()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.code:
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
if isinstance(token_data, ChangeEmailOldEmailToken | ChangeEmailNewEmailToken):
refreshed_token_data = token_data.promote()
else:
# Only advance tokens that were minted by the matching send-code step;
# refuse tokens that have already progressed or lack a phase marker so
# the chain `old_email -> old_email_verified -> new_email -> new_email_verified`
# is strictly enforced.
phase_transitions = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
new_token = AccountService.generate_change_email_token(refreshed_token_data, current_user)
# Refresh token data by generating a new token that carries the
# upgraded phase so later steps can check it.
_, new_token = AccountService.generate_change_email_token(
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@ -731,22 +746,27 @@ class ChangeEmailResetApi(Resource):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
current_user, _ = current_account_with_tenant()
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()
if not reset_data.is_bound_to_account(current_user.id):
raise InvalidTokenError()
if not isinstance(reset_data, ChangeEmailNewEmailVerifiedToken):
# Only tokens that completed both verification phases may be used to
# change the email. This closes GHSA-4q3w-q5mc-45rq where a token from
# the initial send-code step could be replayed directly here.
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
# Bind the new email to the token that was mailed and verified, so a
# verified token cannot be reused with a different `new_email` value.
if reset_data.email.lower() != normalized_new_email:
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
if current_user.email.lower() != reset_data.old_email.lower():
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
# Revoke only after all checks pass so failed attempts don't burn a

View File

@ -1,10 +1,8 @@
from urllib import parse
from uuid import UUID
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
import services
from configs import dify_config
@ -23,15 +21,15 @@ from controllers.console.auth.error import (
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.member_fields import AccountWithRole, AccountWithRoleList
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.account import Account, TenantAccountRole
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@ -80,54 +78,6 @@ def _is_role_enabled(role: TenantAccountRole | str, tenant_id: str) -> bool:
return FeatureService.get_features(tenant_id=tenant_id).dataset_operator_enabled
def _normalize_invitee_emails(emails: list[str]) -> list[str]:
return list(dict.fromkeys(email.lower() for email in emails))
def _count_new_member_invites(tenant_id: str, emails: list[str]) -> int:
new_member_count = 0
for email in emails:
account = AccountService.get_account_by_email_with_case_fallback(email)
if not account:
new_member_count += 1
continue
exists = db.session.scalar(
select(TenantAccountJoin.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == account.id)
.limit(1)
)
if not exists:
new_member_count += 1
return new_member_count
def _count_current_members(tenant_id: str) -> int:
return (
db.session.scalar(select(func.count(TenantAccountJoin.id)).where(TenantAccountJoin.tenant_id == tenant_id)) or 0
)
def _check_member_invite_limits(tenant_id: str, new_member_count: int) -> None:
if new_member_count <= 0:
return
features = FeatureService.get_features(tenant_id=tenant_id)
if dify_config.ENTERPRISE_ENABLED:
workspace_members = features.workspace_members
if workspace_members.enabled is True and not workspace_members.is_available(new_member_count):
raise WorkspaceMembersLimitExceeded()
return
if dify_config.BILLING_ENABLED and features.billing.enabled is True:
members = features.members
current_member_count = _count_current_members(tenant_id)
if 0 < members.limit < current_member_count + new_member_count:
raise WorkspaceMembersLimitExceeded()
@console_ns.route("/workspaces/current/members")
class MemberListApi(Resource):
"""List all members of current tenant."""
@ -154,11 +104,12 @@ class MemberInviteEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("members")
def post(self):
payload = console_ns.payload or {}
args = MemberInvitePayload.model_validate(payload)
invitee_emails = _normalize_invitee_emails(args.emails)
invitee_emails = args.emails
invitee_role = args.role
interface_language = args.language
if not TenantAccountRole.is_non_owner_role(invitee_role):
@ -178,36 +129,37 @@ class MemberInviteEmailApi(Resource):
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL
tenant_id = inviter.current_tenant.id
with redis_client.lock(f"workspace_member_invite:{tenant_id}", timeout=60):
new_member_count = _count_new_member_invites(tenant_id, invitee_emails)
_check_member_invite_limits(tenant_id, new_member_count)
workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members
for invitee_email in invitee_emails:
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
if not workspace_members.is_available(len(invitee_emails)):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",
@ -223,7 +175,7 @@ class MemberCancelInviteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, member_id: UUID):
def delete(self, member_id):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
@ -256,7 +208,7 @@ class MemberUpdateRoleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def put(self, member_id: UUID):
def put(self, member_id):
payload = console_ns.payload or {}
args = MemberRoleUpdatePayload.model_validate(payload)
new_role = args.role
@ -399,7 +351,7 @@ class OwnerTransfer(Resource):
@login_required
@account_initialization_required
@is_allow_transfer_owner
def post(self, member_id: UUID):
def post(self, member_id):
payload = console_ns.payload or {}
args = OwnerTransferPayload.model_validate(payload)

View File

@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")

View File

@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
credential_id=args.credential_id,
)
return "", 204
return {"result": "success"}, 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
@ -532,7 +532,7 @@ class ModelProviderAvailableModelApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, model_type: str):
def get(self, model_type):
_, tenant_id = current_account_with_tenant()
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

View File

@ -15,7 +15,6 @@ from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.plugin_service import PluginService
from fields.base import ResponseModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
@ -23,6 +22,7 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class ParserList(BaseModel):

View File

@ -4,7 +4,6 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from flask import abort, request
from sqlalchemy import select
@ -17,7 +16,6 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.encryption import FieldEncryption
from libs.login import current_account_with_tenant
from models import Account
from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
@ -84,7 +82,9 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.BILLING_ENABLED:
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -198,11 +198,15 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
utm_info = request.cookies.get("utm_info")
if dify_config.BILLING_ENABLED and utm_info:
_, current_tenant_id = current_account_with_tenant()
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)
@ -305,6 +309,7 @@ def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
@ -322,6 +327,7 @@ def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
from werkzeug.exceptions import Forbidden
from libs.login import current_user
from models import Account
user = current_user._get_current_object()
if not isinstance(user, Account) or not user.is_admin_or_owner:
@ -489,25 +495,3 @@ def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return view(*args, **kwargs)
return decorated
def with_current_tenant_id[T, **P, R](
view: Callable[Concatenate[T, str, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_, current_tenant_id = current_account_with_tenant()
return view(self, current_tenant_id, *args, **kwargs)
return decorated
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated

View File

@ -1,5 +1,4 @@
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -50,8 +49,8 @@ class ImagePreviewApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID):
file_id_str = str(file_id)
def get(self, file_id):
file_id = str(file_id)
args = FileSignatureQuery.model_validate(request.args.to_dict(flat=True))
timestamp = args.timestamp
@ -60,7 +59,7 @@ class ImagePreviewApi(Resource):
try:
generator, mimetype = FileService(db.engine).get_image_preview(
file_id=file_id_str,
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
sign=sign,
@ -92,14 +91,14 @@ class FilePreviewApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID):
file_id_str = str(file_id)
def get(self, file_id):
file_id = str(file_id)
args = FilePreviewQuery.model_validate(request.args.to_dict(flat=True))
try:
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id_str,
file_id=file_id,
timestamp=args.timestamp,
nonce=args.nonce,
sign=args.sign,
@ -160,10 +159,10 @@ class WorkspaceWebappLogoApi(Resource):
415: "Unsupported file type",
}
)
def get(self, workspace_id: UUID):
workspace_id_str = str(workspace_id)
def get(self, workspace_id):
workspace_id = str(workspace_id)
custom_config = TenantService.get_custom_config(workspace_id_str)
custom_config = TenantService.get_custom_config(workspace_id)
webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None
if not webapp_logo_file_id:

View File

@ -1,5 +1,4 @@
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -46,19 +45,17 @@ class ToolFileApi(Resource):
415: "Unsupported file type",
}
)
def get(self, file_id: UUID, extension: str):
file_id_str = str(file_id)
def get(self, file_id, extension):
file_id = str(file_id)
args = ToolFileQuery.model_validate(request.args.to_dict())
if not verify_tool_file_signature(
file_id=file_id_str, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign
):
if not verify_tool_file_signature(file_id=file_id, timestamp=args.timestamp, nonce=args.nonce, sign=args.sign):
raise Forbidden("Invalid request.")
try:
tool_file_manager = ToolFileManager()
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id_str,
file_id,
)
if not stream or not tool_file:

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -79,10 +78,10 @@ class AnnotationReplyActionStatusApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App, job_id: UUID, action: str):
def get(self, app_model: App, job_id, action):
"""Get the status of an annotation reply action job."""
job_id_str = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{job_id_str}"
job_id = str(job_id)
app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
@ -90,10 +89,10 @@ class AnnotationReplyActionStatusApi(Resource):
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = f"{action}_app_annotation_error_{job_id_str}"
app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id_str, "job_status": job_status, "error_msg": error_msg}, 200
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
@service_api_ns.route("/apps/annotations")
@ -174,11 +173,11 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def put(self, app_model: App, annotation_id: UUID):
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -195,7 +194,7 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def delete(self, app_model: App, annotation_id: UUID):
def delete(self, app_model: App, annotation_id: str):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return "", 204

View File

@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -196,7 +195,7 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def delete(self, app_model: App, end_user: EndUser, c_id: UUID):
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -225,7 +224,7 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, c_id: UUID):
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@ -267,7 +266,7 @@ class ConversationVariablesApi(Resource):
service_api_ns.models[ConversationVariableInfiniteScrollPaginationResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, c_id: UUID):
def get(self, app_model: App, end_user: EndUser, c_id):
"""List all variables for a conversation.
Conversational variables are only available for chat applications.
@ -313,7 +312,7 @@ class ConversationVariableDetailApi(Resource):
service_api_ns.models[ConversationVariableResponse.__name__],
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def put(self, app_model: App, end_user: EndUser, c_id: UUID, variable_id: UUID):
def put(self, app_model: App, end_user: EndUser, c_id, variable_id):
"""Update a conversation variable's value.
Allows updating the value of a specific conversation variable.
@ -324,13 +323,13 @@ class ConversationVariableDetailApi(Resource):
raise NotChatAppError()
conversation_id = str(c_id)
variable_id_str = str(variable_id)
variable_id = str(variable_id)
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
try:
variable = ConversationService.update_conversation_variable(
app_model, conversation_id, variable_id_str, end_user, payload.value
app_model, conversation_id, variable_id, end_user, payload.value
)
return ConversationVariableResponse.model_validate(variable, from_attributes=True).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:

View File

@ -1,6 +1,5 @@
import logging
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -51,20 +50,20 @@ class FilePreviewApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
def get(self, app_model: App, end_user: EndUser, file_id: str):
"""
Preview/Download a file that was uploaded via Service API.
Provides secure file preview/download functionality.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
file_id_str = str(file_id)
file_id = str(file_id)
# Parse query parameters
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
# Get file content generator
try:

View File

@ -1,5 +1,4 @@
import logging
from uuid import UUID
from flask import request
from flask_restx import Resource
@ -95,19 +94,19 @@ class MessageFeedbackApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id: UUID):
def post(self, app_model: App, end_user: EndUser, message_id):
"""Submit feedback for a message.
Allows users to rate messages as like/dislike and provide optional feedback content.
"""
message_id_str = str(message_id)
message_id = str(message_id)
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id_str,
message_id=message_id,
user=end_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
@ -160,19 +159,19 @@ class MessageSuggestedApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id: UUID):
def get(self, app_model: App, end_user: EndUser, message_id):
"""Get suggested follow-up questions for a message.
Returns AI-generated follow-up questions based on the message content.
"""
message_id_str = str(message_id)
message_id = str(message_id)
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.SERVICE_API
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API
)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal, override
from typing import Literal
from dateutil.parser import isoparse
from flask import request
@ -76,13 +76,11 @@ def _enum_value(value):
class WorkflowRunStatusField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:

View File

@ -1,18 +1,13 @@
from typing import Any, Literal
from uuid import UUID
from typing import Any, Literal, cast
from flask import request
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -22,10 +17,9 @@ from controllers.service_api.wraps import (
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -125,21 +119,6 @@ class TagUnbindingPayload(BaseModel):
return self
class KnowledgeTagResponse(ResponseModel):
model_config = ConfigDict(coerce_numbers_to_str=True)
id: str
name: str
type: str
# TODO: The public Service API docs expose binding_count as string|null.
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
binding_count: str | None = None
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
pass
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
@ -148,29 +127,6 @@ class DatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
# todo: duplicate code, but the partial_member_list has different nullability
class DatasetListResponse(ResponseModel):
data: list[DatasetDetailResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetBoundTagResponse(ResponseModel):
id: str
name: str
class DatasetBoundTagListResponse(ResponseModel):
data: list[DatasetBoundTagResponse]
total: int
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@ -181,17 +137,9 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
register_response_schema_models(
service_api_ns,
SimpleResultResponse,
KnowledgeTagResponse,
KnowledgeTagListResponse,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetBoundTagListResponse,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
@service_api_ns.route("/datasets")
@ -206,18 +154,9 @@ class DatasetListApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
@service_api_ns.response(
200,
"Datasets retrieved successfully",
service_api_ns.models[DatasetListResponse.__name__],
)
def get(self, tenant_id):
"""Resource for getting datasets."""
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = DatasetListQuery.model_validate(query_params)
query = DatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
datasets, total = DatasetService.get_datasets(
@ -236,17 +175,22 @@ class DatasetListApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
data = marshal(datasets, dataset_detail_fields)
for item in data:
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item_model in model_names:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
else:
item["embedding_available"] = False
item["embedding_available"] = False # type: ignore
else:
item["embedding_available"] = True
item["embedding_available"] = True # type: ignore
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@ -254,7 +198,7 @@ class DatasetListApi(DatasetApiResource):
"total": total,
"page": query.page,
}
return dump_response(DatasetListResponse, response), 200
return response, 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@ -266,11 +210,6 @@ class DatasetListApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200,
"Dataset created successfully",
service_api_ns.models[DatasetDetailResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
@ -314,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return dump_response(DatasetDetailResponse, dataset), 200
return marshal(dataset, dataset_detail_fields), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>")
@ -332,12 +271,7 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset retrieved successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
def get(self, _, dataset_id: UUID):
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -346,7 +280,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = dump_response(DatasetDetailResponse, dataset)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
# check embedding setting
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -378,13 +312,7 @@ class DatasetApi(DatasetApiResource):
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return (
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
mode="json",
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
),
200,
)
return data, 200
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@ -398,13 +326,8 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset updated successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id: UUID):
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -453,7 +376,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = dump_response(DatasetDetailResponse, dataset)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
@ -466,7 +389,7 @@ class DatasetApi(DatasetApiResource):
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
return result_data, 200
@service_api_ns.doc("delete_dataset")
@service_api_ns.doc(description="Delete a dataset")
@ -480,7 +403,7 @@ class DatasetApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id: UUID):
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
@ -535,7 +458,7 @@ class DocumentStatusApi(DatasetApiResource):
400: "Bad request - invalid action",
}
)
def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
"""
Batch update document status.
@ -579,7 +502,7 @@ class DocumentStatusApi(DatasetApiResource):
except ValueError as e:
raise InvalidActionError(str(e))
return dump_response(SimpleResultResponse, {"result": "success"}), 200
return {"result": "success"}, 200
@service_api_ns.route("/datasets/tags")
@ -592,18 +515,14 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[KnowledgeTagListResponse.__name__],
)
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
return dump_response(KnowledgeTagListResponse, tags), 200
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@ -615,11 +534,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag created successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@ -629,10 +543,9 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@ -645,11 +558,6 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag updated successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -661,10 +569,9 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@ -749,11 +656,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[DatasetBoundTagListResponse.__name__],
)
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
@ -761,4 +663,5 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
response = {"data": tags_list, "total": len(tags)}
return response, 200

View File

@ -374,7 +374,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
@ -395,6 +395,7 @@ class DocumentAddByFileApi(DatasetApiResource):
args["doc_language"] = "English"
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
@ -585,17 +586,17 @@ class DocumentListApi(DatasetApiResource):
404: "Dataset not found",
}
)
def get(self, tenant_id, dataset_id: UUID):
dataset_id_str = str(dataset_id)
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
query_params = DocumentListQuery.model_validate(request.args.to_dict())
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).where(Document.dataset_id == dataset_id_str, Document.tenant_id == tenant_id)
query = select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == tenant_id)
if query_params.status:
query = DocumentService.apply_display_status_filter(query, query_params.status)
@ -645,7 +646,7 @@ class DocumentBatchDownloadZipApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {})
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
@ -680,17 +681,18 @@ class DocumentIndexingStatusApi(DatasetApiResource):
404: "Dataset or documents not found",
}
)
def get(self, tenant_id, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
def get(self, tenant_id, dataset_id, batch):
dataset_id = str(dataset_id)
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# get documents
documents = DocumentService.get_batch_documents(dataset_id_str, batch)
documents = DocumentService.get_batch_documents(dataset_id, batch)
if not documents:
raise NotFound("Documents not found.")
documents_status = []
@ -755,7 +757,7 @@ class DocumentDownloadApi(DatasetApiResource):
service_api_ns.models[UrlResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def get(self, tenant_id, dataset_id: UUID, document_id: UUID):
def get(self, tenant_id, dataset_id, document_id):
dataset = self.get_dataset(str(dataset_id), str(tenant_id))
document = DocumentService.get_document(dataset.id, str(document_id))
@ -783,13 +785,13 @@ class DocumentApi(DatasetApiResource):
404: "Document not found",
}
)
def get(self, tenant_id, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
def get(self, tenant_id, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = self.get_dataset(dataset_id_str, tenant_id)
dataset = self.get_dataset(dataset_id, tenant_id)
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
@ -806,15 +808,15 @@ class DocumentApi(DatasetApiResource):
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
if has_summary_index and document.need_summary is True:
summary_index_status = SummaryIndexService.get_document_summary_index_status(
document_id=document_id_str,
dataset_id=dataset_id_str,
document_id=document_id,
dataset_id=dataset_id,
tenant_id=tenant_id,
)
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
@ -849,7 +851,7 @@ class DocumentApi(DatasetApiResource):
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id_str)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
@ -916,21 +918,21 @@ class DocumentApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id: UUID, document_id: UUID):
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id_str = str(document_id)
dataset_id_str = str(dataset_id)
document_id = str(document_id)
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
# get dataset info
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
if document is None:

View File

@ -1,5 +1,3 @@
from uuid import UUID
from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.service_api import service_api_ns
@ -22,7 +20,7 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
)
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset.
Tests retrieval performance for the specified dataset.

View File

@ -1,20 +1,15 @@
from typing import Literal
from uuid import UUID
from flask_login import current_user
from flask_restx import marshal
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import (
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -32,13 +27,7 @@ register_schema_models(
DocumentMetadataOperation,
MetadataOperationData,
)
register_response_schema_models(
service_api_ns,
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -54,11 +43,8 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
@ -69,7 +55,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return dump_response(DatasetMetadataResponse, metadata), 201
return marshal(metadata, dataset_metadata_fields), 201
@service_api_ns.doc("get_dataset_metadata")
@service_api_ns.doc(description="Get all metadata for a dataset")
@ -81,17 +67,13 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, tenant_id, dataset_id: UUID):
def get(self, tenant_id, dataset_id):
"""Get all metadata for a dataset."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
return MetadataService.get_dataset_metadatas(dataset), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -107,11 +89,8 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id: UUID, metadata_id: UUID):
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
@ -123,7 +102,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return dump_response(DatasetMetadataResponse, metadata), 200
return marshal(metadata, dataset_metadata_fields), 200
@service_api_ns.doc("delete_dataset_metadata")
@service_api_ns.doc(description="Delete metadata")
@ -135,9 +114,8 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(204, "Metadata deleted successfully")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id: UUID, metadata_id: UUID):
def delete(self, tenant_id, dataset_id, metadata_id):
"""Delete metadata."""
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@ -160,15 +138,10 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Built-in fields retrieved successfully",
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self, tenant_id, dataset_id: UUID):
def get(self, tenant_id, dataset_id):
"""Get all built-in metadata fields."""
built_in_fields = MetadataService.get_built_in_fields()
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
return {"fields": built_in_fields}, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -184,10 +157,12 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
}
)
@service_api_ns.response(
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
200,
"Action completed successfully",
service_api_ns.models[SimpleResultResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable"]):
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
"""Enable or disable built-in metadata field."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -200,7 +175,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
return {"result": "success"}, 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -219,10 +194,10 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.response(
200,
"Documents metadata updated successfully",
service_api_ns.models[DatasetMetadataActionResponse.__name__],
service_api_ns.models[SimpleResultResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id: UUID):
def post(self, tenant_id, dataset_id):
"""Update metadata for multiple documents."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -234,4 +209,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
MetadataService.update_documents_metadata(dataset, metadata_args)
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
return {"result": "success"}, 200

View File

@ -1,6 +1,5 @@
from collections.abc import Generator
from typing import Any
from uuid import UUID
from flask import request
from pydantic import BaseModel
@ -65,11 +64,10 @@ class DatasourcePluginsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id: str, dataset_id: UUID):
def get(self, tenant_id: str, dataset_id: str):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -79,7 +77,7 @@ class DatasourcePluginsApi(DatasetApiResource):
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
)
return datasource_plugins, 200
@ -111,11 +109,10 @@ class DatasourceNodeRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: UUID, node_id: str):
def post(self, tenant_id: str, dataset_id: str, node_id: str):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -123,7 +120,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
@ -175,11 +172,10 @@ class PipelineRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: UUID):
def post(self, tenant_id: str, dataset_id: str):
"""Resource for running a rag pipeline."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -190,7 +186,7 @@ class PipelineRunApi(DatasetApiResource):
raise Forbidden()
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
try:
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,

View File

@ -1,5 +1,4 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal
@ -108,19 +107,17 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
def post(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
@ -153,10 +150,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
}, 200
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@ -171,21 +165,19 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
def get(self, tenant_id: str, dataset_id: str, document_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset_id_str = str(dataset_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
# check embedding model setting
@ -213,7 +205,7 @@ class SegmentApi(DatasetApiResource):
)
segments, total = SegmentService.get_segments(
document_id=document_id_str,
document_id=document_id,
tenant_id=current_tenant_id,
status_list=args.status,
keyword=args.keyword,
@ -222,7 +214,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"data": _marshal_segments_with_summary(segments, dataset_id),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -248,25 +240,22 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@ -287,20 +276,18 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -319,19 +306,15 @@ class DatasetSegmentApi(DatasetApiResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
"doc_form": document.doc_form,
}, 200
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@ -342,29 +325,26 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@service_api_ns.route(
@ -389,26 +369,23 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -452,26 +429,23 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -487,9 +461,7 @@ class ChildChunkApi(DatasetApiResource):
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
@ -525,38 +497,32 @@ class DatasetChildChunkApi(DatasetApiResource):
)
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id_str)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id_str):
if str(segment.document_id) != str(document_id):
raise NotFound("Document not found.")
child_chunk_id_str = str(child_chunk_id)
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -592,38 +558,32 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# get document
document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id_str):
if str(segment.document_id) != str(document_id):
raise NotFound("Segment not found.")
child_chunk_id_str = str(child_chunk_id)
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
if not child_chunk:
raise NotFound("Child chunk not found.")

View File

@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
@service_api_ns.route("/")
class IndexApi(Resource):
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
def get(self) -> dict[str, str]:
def get(self):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",

View File

@ -140,7 +140,7 @@ class CompletionStopApi(WebApiResource):
}
)
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
def post(self, app_model, end_user, task_id: str):
def post(self, app_model, end_user, task_id):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
@ -226,7 +226,7 @@ class ChatStopApi(WebApiResource):
}
)
@web_ns.response(200, "Success", web_ns.models[SimpleResultResponse.__name__])
def post(self, app_model, end_user, task_id: str):
def post(self, app_model, end_user, task_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -1,5 +1,4 @@
from typing import Literal
from uuid import UUID
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator
@ -127,7 +126,7 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
def delete(self, app_model, end_user, c_id: UUID):
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -137,7 +136,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@web_ns.route("/conversations/<uuid:c_id>/name")
@ -166,7 +165,7 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
def post(self, app_model, end_user, c_id: UUID):
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -204,7 +203,7 @@ class ConversationPinApi(WebApiResource):
}
)
@web_ns.response(200, "Conversation pinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model, end_user, c_id: UUID):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
@ -235,7 +234,7 @@ class ConversationUnPinApi(WebApiResource):
}
)
@web_ns.response(200, "Conversation unpinned successfully", web_ns.models[ResultResponse.__name__])
def patch(self, app_model, end_user, c_id: UUID):
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()

View File

@ -1,6 +1,5 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from pydantic import BaseModel, Field, TypeAdapter
@ -133,15 +132,15 @@ class MessageFeedbackApi(WebApiResource):
}
)
@web_ns.response(200, "Feedback submitted successfully", web_ns.models[ResultResponse.__name__])
def post(self, app_model, end_user, message_id: UUID):
message_id_str = str(message_id)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
payload = MessageFeedbackPayload.model_validate(web_ns.payload or {})
try:
MessageService.create_feedback(
app_model=app_model,
message_id=message_id_str,
message_id=message_id,
user=end_user,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
@ -167,11 +166,11 @@ class MessageMoreLikeThisApi(WebApiResource):
500: "Internal Server Error",
}
)
def get(self, app_model, end_user, message_id: UUID):
def get(self, app_model, end_user, message_id):
if app_model.mode != "completion":
raise NotCompletionAppError()
message_id_str = str(message_id)
message_id = str(message_id)
raw_args = request.args.to_dict()
query = MessageMoreLikeThisQuery.model_validate(raw_args)
@ -182,7 +181,7 @@ class MessageMoreLikeThisApi(WebApiResource):
response = AppGenerateService.generate_more_like_this(
app_model=app_model,
user=end_user,
message_id=message_id_str,
message_id=message_id,
invoke_from=InvokeFrom.WEB_APP,
streaming=streaming,
)
@ -223,16 +222,16 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
def get(self, app_model, end_user, message_id: UUID):
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id_str = str(message_id)
message_id = str(message_id)
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=end_user, message_id=message_id_str, invoke_from=InvokeFrom.WEB_APP
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
except MessageNotExistsError:

View File

@ -1,5 +1,6 @@
import urllib.parse
import httpx
from flask import request
from pydantic import BaseModel, Field, HttpUrl
import services
@ -58,7 +59,7 @@ class RemoteFileInfoApi(WebApiResource):
Raises:
HTTPException: If the remote file cannot be accessed
"""
decoded_url = helpers.decode_remote_url(url, request.query_string)
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method

View File

@ -1,5 +1,3 @@
from uuid import UUID
from flask import request
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
@ -106,12 +104,12 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
def delete(self, app_model, end_user, message_id: UUID):
message_id_str = str(message_id)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
if app_model.mode != "completion":
raise NotCompletionAppError()
SavedMessageService.delete(app_model, end_user, message_id_str)
SavedMessageService.delete(app_model, end_user, message_id)
return "", 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@ -22,6 +22,9 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
@ -147,9 +150,44 @@ class BaseAgentRunner(AppRunner):
message_tool = PromptMessageTool(
name=tool.tool_name,
description=tool_entity.entity.description.llm,
parameters=tool_entity.get_llm_parameters_json_schema(),
parameters={
"type": "object",
"properties": {},
"required": [],
},
)
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
@ -214,7 +252,40 @@ class BaseAgentRunner(AppRunner):
"""
update prompt message tool
"""
prompt_tool.parameters = tool.get_llm_parameters_json_schema()
# try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool
def create_agent_thought(

View File

@ -1,5 +1,4 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.file import file_manager
@ -67,7 +66,6 @@ class CotChatAgentRunner(CotAgentRunner):
return prompt_messages
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize

View File

@ -1,5 +1,4 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.model_runtime.entities.message_entities import (
@ -52,7 +51,6 @@ class CotCompletionAgentRunner(CotAgentRunner):
return historic_prompt
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Any, override
from typing import Any
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
@ -23,7 +23,6 @@ class PluginAgentStrategy(BaseAgentStrategy):
self.declaration = declaration
self.meta_version = meta_version
@override
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
@ -35,7 +34,6 @@ class PluginAgentStrategy(BaseAgentStrategy):
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
@override
def _invoke(
self,
params: dict[str, Any],

View File

@ -55,7 +55,6 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
@ -146,15 +145,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
except ConversationNotExistsError:
if invoke_from == InvokeFrom.SERVICE_API:
conversation = None
else:
raise
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast, override
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -20,7 +20,6 @@ class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -60,7 +59,6 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -78,7 +76,6 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:
@ -110,7 +107,6 @@ class AdvancedChatAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast, override
from typing import Any, cast
from pydantic import JsonValue
@ -16,7 +16,6 @@ from core.app.entities.task_entities import (
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -38,7 +37,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -56,7 +54,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -88,7 +85,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast, override
from typing import Any, cast
from pydantic import JsonValue
@ -16,7 +16,6 @@ from core.app.entities.task_entities import (
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -38,7 +37,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -56,7 +54,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -88,7 +85,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

Some files were not shown because too many files have changed in this diff Show More