Compare commits

..

40 Commits

Author SHA1 Message Date
yyh
ef2ce4aeb8 test(web): align attribution bootstrap export 2026-05-23 22:28:08 +08:00
yyh
dfa2851be1 chore(web): remove select-auto in body 2026-05-23 22:14:08 +08:00
72ee50c74f refactor: add missing @override decorators to method overrides (#36501)
Co-authored-by: EvanYao826 <evanyao826@gmail.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-23 09:56:36 +00:00
8d99326fb3 feat(plugin): cache plugin model providers by tenant (#36449)
Co-authored-by: WH-2099 <wh2099@pm.me>
2026-05-23 09:12:09 +00:00
2a0c098857 refactor: convert isinstance chains to match/case in otel parser (#36534)
Co-authored-by: Cowork 3P <cowork-3p@localhost>
2026-05-22 18:39:24 +00:00
790ca72627 refactor(api): migrate console/service_api.dataset to BaseModel (#36480) 2026-05-22 17:39:07 +00:00
4d8b6c7dc0 refactor: add missing @override decorator to remaining MCP, Jieba, embeddings, and misc subclasses (#36528) 2026-05-22 13:45:35 +00:00
473c945839 chore: seprate vector space quota query (#36514)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-22 09:26:17 +00:00
yyh
a698c60b29 fix(web): stabilize document picker search focus (#36525) 2026-05-22 09:24:37 +00:00
yyh
24bab5fb2a refactor(web): improve retrieval and tag control semantics (#36521) 2026-05-22 09:01:31 +00:00
yyh
93b7a81071 fix(dify-ui): align form label guidance (#36510) 2026-05-22 07:29:57 +00:00
157e6244dd refactor: add missing @override decorator to agent runners, tool caches, and logging extensions (#36511) 2026-05-22 06:41:48 +00:00
yyh
964aaad7ed refactor: streamline workflow context menu lifecycle (#36500) 2026-05-22 04:31:39 +00:00
92181dbe09 fix(api): preserve remote file URL query params (#36478) 2026-05-22 01:45:20 +00:00
30deef45d9 fix(api): pass SSL verify flag to SSRF proxy mounts (#36455) 2026-05-22 01:31:46 +00:00
ee28074390 refactor: add missing @override decorator to Moderation subclasses (#36492) 2026-05-21 19:42:20 +00:00
1fb491337b refactor: add missing @override decorator to datasource plugin classes (#36494) 2026-05-21 19:41:42 +00:00
82b0a03f5a refactor: add missing @override decorator to PluginModelRuntime (#36493) 2026-05-21 19:40:40 +00:00
6185016910 refactor: add missing @override decorator to file access controller and workflow file runtime (#36495) 2026-05-21 19:39:51 +00:00
b4f5f4869f refactor: add missing @override decorator to code executor providers and transformers (#36496) 2026-05-21 19:39:10 +00:00
7ecbed3b04 chore: add Type to test (#36454)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-21 19:25:03 +00:00
5b58defd62 refactor: add missing @override decorator to GraphEngineLayer subclasses (#36491) 2026-05-21 16:32:02 +00:00
73196de5e1 refactor: add missing @override decorator to AppQueueManager subclasses (#36490) 2026-05-21 16:25:07 +00:00
ea5e487d3c fix(api): stop returning 204 with response body and add CI check (#36489) 2026-05-21 16:20:34 +00:00
f19702f76c feat(api): Flask-RESTX response() vs actual return value checker (#36488) 2026-05-21 15:05:06 +00:00
092c8bca81 refactor(api): migrate console.datasets.metadata to BaseModel (#36450)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-21 15:04:42 +00:00
c50d504c44 refactor: add missing @override decorator to AppGenerateResponseConverter subclasses (#36486) 2026-05-21 14:00:22 +00:00
1b4356b66a fix(ci): bad pyinfra type coverage report comments (#36482) 2026-05-21 12:08:24 +00:00
yyh
7f633622aa fix(web): use popup-open selectors for trigger styles (#36471) 2026-05-21 06:13:11 +00:00
yyh
66f5ab4cfc feat: add dify-ui input primitive (#36446) 2026-05-21 03:15:38 +00:00
0cf9597f52 fix: suggested questions API crash on legacy conversation override configs (#36459) 2026-05-21 01:58:52 +00:00
60cd346fa6 feat: wire workflow agent node runtime (#36437)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 12:39:45 +00:00
56d4d54c16 chore: compatiable conversation is not exists (#33274)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-20 12:37:33 +00:00
yyh
9f9cb4d17e feat(ui): migrate radio to Base UI and update web callsites (#36451)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 12:05:31 +00:00
yyh
7d0d9019d8 chore: upgrade base ui to 1.5.0 (#36442) 2026-05-20 09:58:08 +00:00
d646bcf257 chore: remove unused pyrefly ignore comments in dataset.py (#36443)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 09:14:43 +00:00
e3b45a48eb fix: allow config pubsub join timeout for lower post-run latency (#36438)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
2026-05-20 08:45:51 +00:00
848c15a265 chore: update to only SaaS can view template (#36440) 2026-05-20 08:18:26 +00:00
yyh
be8627233d ci: show web test shard failures (#36436) 2026-05-20 08:03:15 +00:00
1fe8b7fb1d fix(auth): use validity-returned token in ChangePasswordForm reset submit (#36415) 2026-05-20 07:59:09 +00:00
424 changed files with 14102 additions and 6079 deletions

View File

@ -63,8 +63,8 @@ jobs:
id: render
run: |
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"
--base "$GITHUB_WORKSPACE/base_report.json" \
< "$GITHUB_WORKSPACE/pr_report.json")"
{
echo "### Pyrefly Type Coverage"

View File

@ -65,6 +65,9 @@ jobs:
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
# Keep fork-PR comments correct while the trusted workflow_run job is
# still using the default-branch renderer, which resolves --base from api/.
cp /tmp/pyrefly_report_base.json api/base_report.json
- name: Save PR number
run: |
@ -77,6 +80,7 @@ jobs:
path: |
pr_report.json
base_report.json
api/base_report.json
pr_number.txt
- name: Comment PR with type coverage

View File

@ -47,6 +47,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Response Contract Linter
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
- name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: make type-check-core

View File

@ -39,7 +39,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Run tests
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
run: vp test run --reporter=blob --reporter=minimal --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}

View File

@ -75,13 +75,19 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
@uv run --project api --dev ruff format ./api
@uv run --project api --dev ruff check --fix ./api
@$(MAKE) api-contract-lint
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
api-contract-lint:
@echo "🔎 Linting Flask response contracts..."
@uv run --project api --dev python api/dev/lint_response_contracts.py
@echo "✅ Response contract lint complete"
type-check:
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@ -191,6 +197,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@ -204,4 +211,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all

View File

@ -657,6 +657,7 @@ PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
PLUGIN_MODEL_PROVIDERS_CACHE_TTL=86400
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
@ -767,6 +768,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true

View File

@ -195,6 +195,7 @@ Before opening a PR / submitting:
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise docstrings and comments.
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,

View File

@ -49,6 +49,7 @@ class AgentBackendModelConfig(BaseModel):
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@ -138,6 +139,7 @@ class AgentBackendRunRequestBuilder:
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
),
]

View File

@ -11,6 +11,7 @@ from configs import dify_config
from core.helper import encrypter
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.impl.plugin import PluginInstaller
from core.plugin.plugin_service import PluginService
from core.tools.utils.system_encryption import encrypt_system_params
from extensions.ext_database import db
from models import Tenant
@ -20,7 +21,6 @@ from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from models.tools import ToolOAuthSystemClient
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)

View File

@ -1,3 +1,4 @@
from configs.extra.agent_backend_config import AgentBackendConfig
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
@ -5,6 +6,7 @@ from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
AgentBackendConfig,
ArchiveStorageConfig,
NotionConfig,
SentryConfig,

View File

@ -0,0 +1,23 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class AgentBackendConfig(BaseSettings):
"""
Configuration settings for the Agent backend runtime integration.
"""
AGENT_BACKEND_BASE_URL: str | None = Field(
description="Base URL for the Dify Agent backend service.",
default=None,
)
AGENT_BACKEND_USE_FAKE: bool = Field(
description="Use the deterministic in-process fake Agent backend client.",
default=False,
)
AGENT_BACKEND_FAKE_SCENARIO: str = Field(
description="Scenario used by the fake Agent backend client.",
default="success",
)

View File

@ -265,6 +265,11 @@ class PluginConfig(BaseSettings):
default=60 * 60,
)
PLUGIN_MODEL_PROVIDERS_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching tenant plugin model providers in Redis",
default=60 * 60 * 24,
)
PLUGIN_MAX_FILE_SIZE: PositiveInt = Field(
description="Maximum allowed size (bytes) for plugin-generated files",
default=50 * 1024 * 1024,

View File

@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic.types import NonNegativeInt
from pydantic_settings import BaseSettings
@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
default=600,
)
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
description=(
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
"an SSE client and the response stream actually closing.\n\n"
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
"return promptly while the daemon listener thread cleans itself up on the next poll "
"boundary - safe because the listener holds no critical state and exits within one poll "
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
),
default=2000,
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from pydantic import Field
from pydantic.fields import FieldInfo
@ -48,6 +48,7 @@ class ApolloSettingsSource(RemoteSettingsSource):
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")

View File

@ -1,7 +1,7 @@
import logging
import os
from collections.abc import Mapping
from typing import Any
from typing import Any, override
from pydantic.fields import FieldInfo
@ -41,6 +41,7 @@ class NacosSettingsSource(RemoteSettingsSource):
except Exception as e:
raise RuntimeError(f"Failed to parse config: {e}")
@override
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
field_value = self.remote_configs.get(field_name)
if field_value is None:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
from typing import Any, Protocol, final, override, runtime_checkable
from pydantic import BaseModel
@ -133,10 +133,12 @@ class NullAppContext(AppContext):
self._config = config or {}
self._extensions: dict[str, Any] = {}
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
@ -146,6 +148,7 @@ class NullAppContext(AppContext):
self._extensions[name] = extension
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield

View File

@ -6,7 +6,7 @@ import contextvars
import threading
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final
from typing import Any, final, override
from flask import Flask, current_app, g
@ -30,15 +30,18 @@ class FlaskAppContext(AppContext):
"""
self._flask_app = flask_app
@override
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
@override
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
@override
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():

View File

@ -36,6 +36,24 @@ class FileInfo(BaseModel):
size: int
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
decoded_url = urllib.parse.unquote(url)
if isinstance(query_string, bytes):
raw_query = query_string.decode()
else:
raw_query = query_string
if not raw_query:
return decoded_url
if decoded_url.endswith(("?", "&")):
separator = ""
elif urllib.parse.urlsplit(decoded_url).query:
separator = "&"
else:
separator = "?"
return f"{decoded_url}{separator}{raw_query}"
def guess_file_info_from_response(response: httpx.Response):
url = str(response.url)
# Try to extract filename from URL

View File

@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")

View File

@ -269,12 +269,12 @@ class AnnotationApi(Resource):
"message": "annotation_ids are required if the parameter is provided.",
}, 400
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return result, 204
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return "", 204
# If no annotation_ids are provided, handle clearing all annotations
else:
AppAnnotationService.clear_all_annotations(str(app_id))
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")

View File

@ -633,7 +633,7 @@ class AppApi(Resource):
app_service = AppService()
app_service.delete_app(app_model)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/copy")

View File

@ -29,9 +29,6 @@ from fields.conversation_fields import (
from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from fields.conversation_fields import (
ResultResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -77,7 +74,6 @@ register_schema_models(
ConversationMessageDetailResponse,
ConversationWithSummaryPaginationResponse,
ConversationDetailResponse,
ResultResponse,
CompletionConversationQuery,
ChatConversationQuery,
)
@ -194,7 +190,7 @@ class CompletionConversationDetailApi(Resource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
@ -347,7 +343,7 @@ class ChatConversationDetailApi(Resource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
def _get_conversation(app_model, conversation_id):

View File

@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204
return "", 204
except Exception as e:
raise BadRequest(str(e))

View File

@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
user_id=current_user.id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
user_id=current_user.id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204
return "", 204

View File

@ -1,15 +1,16 @@
from typing import Any, cast
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
@ -30,26 +31,10 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import build_icon_url, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
@ -61,58 +46,6 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
@ -208,9 +141,165 @@ class ConsoleDatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetListItemResponse(DatasetDetailResponse):
partial_member_list: list[str]
class DatasetListResponse(ResponseModel):
data: list[DatasetListItemResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
class DatasetQueryFileInfoResponse(ResponseModel):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetQueryContentResponse(ResponseModel):
content_type: str
content: str
file_info: DatasetQueryFileInfoResponse | None = None
class DatasetQueryDetailResponse(ResponseModel):
id: str
queries: list[DatasetQueryContentResponse]
source: str
source_app_id: str | None
created_by_role: str
created_by: str
created_at: int
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DatasetQueryListResponse(ResponseModel):
data: list[DatasetQueryDetailResponse]
has_more: bool
limit: int
total: int
page: int
class RelatedAppResponse(ResponseModel):
id: str
name: str
description: str
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None
icon: str | None
icon_background: str | None
icon_url: str | None = None
@model_validator(mode="after")
def _set_icon_url(self) -> "RelatedAppResponse":
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
return self
class RelatedAppListResponse(ResponseModel):
data: list[RelatedAppResponse]
total: int
class DocumentStatusResponse(ResponseModel):
id: str
indexing_status: str
processing_started_at: int | None
parsing_completed_at: int | None
cleaning_completed_at: int | None
splitting_completed_at: int | None
completed_at: int | None
paused_at: int | None
error: str | None
stopped_at: int | None
completed_segments: int | None = None
total_segments: int | None = None
@field_validator(
"processing_started_at",
"parsing_completed_at",
"cleaning_completed_at",
"splitting_completed_at",
"completed_at",
"paused_at",
"stopped_at",
mode="before",
)
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentStatusListResponse(ResponseModel):
data: list[DocumentStatusResponse]
class ErrorDocsResponse(DocumentStatusListResponse):
total: int
class IndexingEstimatePreviewItemResponse(ResponseModel):
content: str
child_chunks: list[str] | None = None
summary: str | None = None
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
question: str
answer: str
class IndexingEstimateResponse(ResponseModel):
total_segments: int
preview: list[IndexingEstimatePreviewItemResponse]
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
class RetrievalSettingResponse(ResponseModel):
retrieval_method: list[str]
class PartialMemberListResponse(ResponseModel):
data: list[str]
class AutoDisableLogsResponse(ResponseModel):
document_ids: list[str]
count: int
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
register_response_schema_models(
console_ns,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetQueryListResponse,
IndexingEstimateResponse,
RelatedAppListResponse,
DocumentStatusListResponse,
ErrorDocsResponse,
RetrievalSettingResponse,
PartialMemberListResponse,
AutoDisableLogsResponse,
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -293,17 +382,8 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
class DatasetListApi(Resource):
@console_ns.doc("get_datasets")
@console_ns.doc(description="Get list of datasets")
@console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
"ids": "Filter by dataset IDs (list)",
"keyword": "Search keyword",
"tag_ids": "Filter by tag IDs (list)",
"include_all": "Include all datasets (default: false)",
}
)
@console_ns.response(200, "Datasets retrieved successfully")
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -342,7 +422,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
partial_members_map: dict[str, list[str]] = {}
if dataset_ids:
@ -379,12 +459,12 @@ class DatasetListApi(Resource):
"total": total,
"page": query.page,
}
return response, 200
return dump_response(DatasetListResponse, response), 200
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@ -413,7 +493,7 @@ class DatasetListApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 201
return dump_response(DatasetDetailResponse, dataset), 201
@console_ns.route("/datasets/<uuid:dataset_id>")
@ -421,7 +501,11 @@ class DatasetApi(Resource):
@console_ns.doc("get_dataset")
@console_ns.doc(description="Get dataset details")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
@console_ns.response(
200,
"Dataset retrieved successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -437,7 +521,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
data = dump_response(DatasetDetailResponse, dataset)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -470,7 +554,11 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(
200,
"Dataset updated successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -506,7 +594,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
result_data = dump_response(DatasetDetailResponse, dataset)
tenant_id = current_tenant_id
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
@ -535,7 +623,7 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
return "", 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@ -567,7 +655,11 @@ class DatasetQueryApi(Resource):
@console_ns.doc("get_dataset_queries")
@console_ns.doc(description="Get dataset query history")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
@console_ns.response(
200,
"Query history retrieved successfully",
console_ns.models[DatasetQueryListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -589,20 +681,24 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
"data": marshal(dataset_queries, dataset_query_detail_model),
"data": dataset_queries,
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
return dump_response(DatasetQueryListResponse, response), 200
@console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource):
@console_ns.doc("estimate_dataset_indexing")
@console_ns.doc(description="Estimate dataset indexing cost")
@console_ns.response(200, "Indexing estimate calculated successfully")
@console_ns.response(
200,
"Indexing estimate calculated successfully",
console_ns.models[IndexingEstimateResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -699,11 +795,14 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.doc("get_dataset_related_apps")
@console_ns.doc(description="Get applications related to dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
@console_ns.response(
200,
"Related apps retrieved successfully",
console_ns.models[RelatedAppListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@marshal_with(related_app_list_model)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -724,7 +823,7 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return {"data": related_apps, "total": len(related_apps)}, 200
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
@ -732,7 +831,11 @@ class DatasetIndexingStatusApi(Resource):
@console_ns.doc("get_dataset_indexing_status")
@console_ns.doc(description="Get dataset indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Indexing status retrieved successfully")
@console_ns.response(
200,
"Indexing status retrieved successfully",
console_ns.models[DocumentStatusListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -778,9 +881,8 @@ class DatasetIndexingStatusApi(Resource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data, 200
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
@console_ns.route("/datasets/api-keys")
@ -873,7 +975,7 @@ class DatasetApiDeleteApi(Resource):
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
@ -907,13 +1009,18 @@ class DatasetApiBaseUrlApi(Resource):
class DatasetRetrievalSettingApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting")
@console_ns.doc(description="Get dataset retrieval settings")
@console_ns.response(200, "Retrieval settings retrieved successfully")
@console_ns.response(
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -921,12 +1028,19 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting_mock")
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
@console_ns.doc(params={"vector_type": "Vector store type"})
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
@console_ns.response(
200,
"Mock retrieval settings retrieved successfully",
console_ns.models[RetrievalSettingResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type):
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@ -934,7 +1048,7 @@ class DatasetErrorDocs(Resource):
@console_ns.doc("get_dataset_error_docs")
@console_ns.doc(description="Get dataset error documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Error documents retrieved successfully")
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@ -946,7 +1060,7 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
@ -954,7 +1068,11 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.doc("get_dataset_permission_users")
@console_ns.doc(description="Get dataset permission user list")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Permission users retrieved successfully")
@console_ns.response(
200,
"Permission users retrieved successfully",
console_ns.models[PartialMemberListResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -973,9 +1091,7 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
"data": partial_members_list,
}, 200
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
@ -983,7 +1099,11 @@ class DatasetAutoDisableLogApi(Resource):
@console_ns.doc("get_dataset_auto_disable_logs")
@console_ns.doc(description="Get dataset auto disable logs")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Auto disable logs retrieved successfully")
@console_ns.response(
200,
"Auto disable logs retrieved successfully",
console_ns.models[AutoDisableLogsResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@ -993,4 +1113,4 @@ class DatasetAutoDisableLogApi(Resource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200

View File

@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/init")
@ -966,7 +966,7 @@ class DocumentApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot pause completed document.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Document is not in paused status.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource):
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")

View File

@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 204
return "", 204
@console_ns.route(
@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 204
return "", 204
@setup_required
@login_required

View File

@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")

View File

@ -1,14 +1,18 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
@ -22,7 +26,12 @@ from services.metadata_service import MetadataService
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
register_response_schema_models(console_ns, SimpleResultResponse)
register_response_schema_models(
console_ns,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -31,7 +40,7 @@ class DatasetMetadataCreateApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
@ -44,18 +53,22 @@ class DatasetMetadataCreateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return metadata, 201
return dump_response(DatasetMetadataResponse, metadata), 201
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -64,7 +77,7 @@ class DatasetMetadataApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
@ -79,7 +92,7 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200
return dump_response(DatasetMetadataResponse, metadata), 200
@setup_required
@login_required
@ -96,7 +109,8 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return {"result": "success"}, 204
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
@console_ns.route("/datasets/metadata/built-in")
@ -105,9 +119,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200,
"Built-in fields retrieved successfully",
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self):
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -116,7 +135,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -130,7 +149,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -140,7 +160,10 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -153,4 +176,5 @@ class DocumentMetadataEditApi(Resource):
MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200
# Frontend callers only await success and invalidate caches; no response body is consumed.
return "", 204

View File

@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@console_ns.route(

View File

@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
db.session.delete(installed_app)
db.session.commit()
return {"result": "success", "message": "App uninstalled successfully"}, 204
return "", 204
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
def patch(self, installed_app):

View File

@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204

View File

@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}, 204
return "", 204

View File

@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
@console_ns.route("/features")
@ -28,7 +28,32 @@ class FeatureApi(Resource):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
).model_dump()
payload.pop("vector_space", None)
return payload
@console_ns.route("/features/vector-space")
class FeatureVectorSpaceApi(Resource):
@console_ns.doc("get_tenant_feature_vector_space")
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
@console_ns.response(
200,
"Success",
console_ns.models[LimitationModel.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
"""Get vector-space usage and limit for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_vector_space(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -1,6 +1,5 @@
import urllib.parse
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -34,7 +33,7 @@ class GetRemoteFileInfo(Resource):
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
@login_required
def get(self, url: str):
decoded_url = urllib.parse.unquote(url)
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)

View File

@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")

View File

@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
credential_id=args.credential_id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")

View File

@ -15,6 +15,7 @@ from controllers.console import console_ns
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.plugin_service import PluginService
from fields.base import ResponseModel
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
@ -22,7 +23,6 @@ from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermissi
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
class ParserList(BaseModel):

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from typing import Literal, override
from dateutil.parser import isoparse
from flask import request
@ -76,11 +76,13 @@ def _enum_value(value):
class WorkflowRunStatusField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:

View File

@ -1,13 +1,17 @@
from typing import Any, Literal, cast
from typing import Any, Literal
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -17,9 +21,10 @@ from controllers.service_api.wraps import (
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -119,6 +124,21 @@ class TagUnbindingPayload(BaseModel):
return self
class KnowledgeTagResponse(ResponseModel):
model_config = ConfigDict(coerce_numbers_to_str=True)
id: str
name: str
type: str
# TODO: The public Service API docs expose binding_count as string|null.
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
binding_count: str | None = None
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
pass
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
@ -127,6 +147,29 @@ class DatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
# todo: duplicate code, but the partial_member_list has different nullability
class DatasetListResponse(ResponseModel):
data: list[DatasetDetailResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetBoundTagResponse(ResponseModel):
id: str
name: str
class DatasetBoundTagListResponse(ResponseModel):
data: list[DatasetBoundTagResponse]
total: int
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@ -137,9 +180,17 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
register_response_schema_models(
service_api_ns,
SimpleResultResponse,
KnowledgeTagResponse,
KnowledgeTagListResponse,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetBoundTagListResponse,
)
@service_api_ns.route("/datasets")
@ -154,9 +205,18 @@ class DatasetListApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
@service_api_ns.response(
200,
"Datasets retrieved successfully",
service_api_ns.models[DatasetListResponse.__name__],
)
def get(self, tenant_id):
"""Resource for getting datasets."""
query = DatasetListQuery.model_validate(request.args.to_dict())
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = DatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor")
datasets, total = DatasetService.get_datasets(
@ -175,22 +235,17 @@ class DatasetListApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
for item in data:
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
else:
item["embedding_available"] = False # type: ignore
item["embedding_available"] = False
else:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@ -198,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
"total": total,
"page": query.page,
}
return response, 200
return dump_response(DatasetListResponse, response), 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@ -210,6 +265,11 @@ class DatasetListApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200,
"Dataset created successfully",
service_api_ns.models[DatasetDetailResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
@ -253,7 +313,7 @@ class DatasetListApi(DatasetApiResource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
return dump_response(DatasetDetailResponse, dataset), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>")
@ -271,6 +331,11 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset retrieved successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -280,7 +345,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
data = dump_response(DatasetDetailResponse, dataset)
# check embedding setting
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -312,7 +377,13 @@ class DatasetApi(DatasetApiResource):
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
return (
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
mode="json",
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
),
200,
)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@ -326,6 +397,11 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset updated successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
@ -376,7 +452,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
result_data = dump_response(DatasetDetailResponse, dataset)
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
@ -389,7 +465,7 @@ class DatasetApi(DatasetApiResource):
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
@service_api_ns.doc("delete_dataset")
@service_api_ns.doc(description="Delete a dataset")
@ -502,7 +578,7 @@ class DocumentStatusApi(DatasetApiResource):
except ValueError as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
return dump_response(SimpleResultResponse, {"result": "success"}), 200
@service_api_ns.route("/datasets/tags")
@ -515,14 +591,18 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[KnowledgeTagListResponse.__name__],
)
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
return dump_response(KnowledgeTagListResponse, tags), 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@ -534,6 +614,11 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag created successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@ -543,9 +628,10 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
)
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@ -558,6 +644,11 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag updated successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -569,9 +660,10 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
)
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@ -656,6 +748,11 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[DatasetBoundTagListResponse.__name__],
)
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
@ -663,5 +760,4 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)}
return response, 200
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200

View File

@ -1,15 +1,19 @@
from typing import Literal
from flask_login import current_user
from flask_restx import marshal
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
from fields.dataset_fields import (
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -27,7 +31,13 @@ register_schema_models(
DocumentMetadataOperation,
MetadataOperationData,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
register_response_schema_models(
service_api_ns,
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -43,6 +53,9 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
@ -55,7 +68,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return marshal(metadata, dataset_metadata_fields), 201
return dump_response(DatasetMetadataResponse, metadata), 201
@service_api_ns.doc("get_dataset_metadata")
@service_api_ns.doc(description="Get all metadata for a dataset")
@ -67,13 +80,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, tenant_id, dataset_id):
"""Get all metadata for a dataset."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -89,6 +106,9 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
@ -102,7 +122,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
return dump_response(DatasetMetadataResponse, metadata), 200
@service_api_ns.doc("delete_dataset_metadata")
@service_api_ns.doc(description="Delete metadata")
@ -114,6 +134,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(204, "Metadata deleted successfully")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id):
"""Delete metadata."""
@ -138,10 +159,15 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Built-in fields retrieved successfully",
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self, tenant_id, dataset_id):
"""Get all built-in metadata fields."""
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -157,9 +183,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
}
)
@service_api_ns.response(
200,
"Action completed successfully",
service_api_ns.models[SimpleResultResponse.__name__],
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
@ -175,7 +199,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -194,7 +218,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.response(
200,
"Documents metadata updated successfully",
service_api_ns.models[SimpleResultResponse.__name__],
service_api_ns.models[DatasetMetadataActionResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
@ -209,4 +233,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200

View File

@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
@service_api_ns.route("/")
class IndexApi(Resource):
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
def get(self):
def get(self) -> dict[str, str]:
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",

View File

@ -136,7 +136,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@web_ns.route("/conversations/<uuid:c_id>/name")

View File

@ -1,6 +1,5 @@
import urllib.parse
import httpx
from flask import request
from pydantic import BaseModel, Field, HttpUrl
import services
@ -59,7 +58,7 @@ class RemoteFileInfoApi(WebApiResource):
Raises:
HTTPException: If the remote file cannot be accessed
"""
decoded_url = urllib.parse.unquote(url)
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method

View File

@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204

View File

@ -1,4 +1,5 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.file import file_manager
@ -66,6 +67,7 @@ class CotChatAgentRunner(CotAgentRunner):
return prompt_messages
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize

View File

@ -1,4 +1,5 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.model_runtime.entities.message_entities import (
@ -51,6 +52,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
return historic_prompt
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Any
from typing import Any, override
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
self.declaration = declaration
self.meta_version = meta_version
@override
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
@override
def _invoke(
self,
params: dict[str, Any],

View File

@ -55,6 +55,7 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
@ -145,9 +146,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
except ConversationNotExistsError:
if invoke_from == InvokeFrom.SERVICE_API:
conversation = None
else:
raise
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
from typing import Any, Protocol, override
from graphon.enums import NodeType
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
class NoopDraftVariableSaver(DraftVariableSaver):
@override
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
return None

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
self._app_mode = app_mode
self._message_id = str(message_id)
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
return dict(blocking_response.model_dump())
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
return cls.convert_blocking_full_response(blocking_response)
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -73,6 +76,7 @@ class WorkflowAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import override
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
@ -31,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
) -> None:
self._scope_getter = scope_getter
@override
def current_scope(self) -> FileAccessScope | None:
return self._scope_getter()
@override
def apply_upload_file_filters(
self,
stmt: Select[tuple[UploadFile]],
@ -62,6 +65,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
)
)
@override
def apply_tool_file_filters(
self,
stmt: Select[tuple[ToolFile]],
@ -78,6 +82,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
@override
def get_upload_file(
self,
*,
@ -95,6 +100,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
)
return session.scalar(stmt)
@override
def get_tool_file(
self,
*,

View File

@ -8,6 +8,7 @@ scope updates that matter to chat applications.
"""
import logging
from typing import override
from core.workflow.system_variables import SystemVariableKey, get_system_text
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
@ -23,9 +24,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
@override
def on_graph_start(self) -> None:
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunVariableUpdatedEvent):
return
@ -44,5 +47,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self
from typing import Annotated, Literal, Self, override
from pydantic import BaseModel, Field
from sqlalchemy import Engine
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@override
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
pause_reasons=event.reasons,
)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.

View File

@ -1,3 +1,5 @@
from typing import override
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
super().__init__()
self._paused = False
@override
def on_graph_start(self):
self._paused = False
@override
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
if isinstance(event, GraphRunPausedEvent):
self._paused = True
@override
def on_graph_end(self, error: Exception | None):
""" """
self._paused = False

View File

@ -1,6 +1,6 @@
import logging
import uuid
from typing import ClassVar
from typing import ClassVar, override
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
except Exception:
logger.exception("scheduler error during check if the workflow need to be suspended")
@override
def on_graph_start(self):
"""
Start timer to check if the workflow need to be suspended.
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
id=self.schedule_id,
)
@override
def on_event(self, event: GraphEngineEvent):
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
self.stopped = True
# remove the scheduler

View File

@ -1,6 +1,6 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar
from typing import Any, ClassVar, override
from pydantic import TypeAdapter
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@override
def on_graph_start(self):
pass
@override
def on_event(self, event: GraphEngineEvent):
"""
Update trigger log with success or failure.
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
repo.update(trigger_log)
session.commit()
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -7,7 +7,7 @@ import os
import time
import urllib.parse
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, override
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
self._file_access_controller = file_access_controller
@property
@override
def multimodal_send_format(self) -> str:
return dify_config.MULTIMODAL_SEND_FORMAT
@override
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
@override
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
@override
def load_file_bytes(self, *, file: File) -> bytes:
storage_key = self._resolve_storage_key(file=file)
data = storage.load(storage_key, stream=False)
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
raise ValueError(f"file {storage_key} is not a bytes object")
return data
@override
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
return file.remote_url
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
)
return None
@override
def resolve_upload_file_url(
self,
*,
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
query["as_attachment"] = "true"
return f"{url}?{urllib.parse.urlencode(query)}"
@override
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._assert_tool_file_access(tool_file_id=tool_file_id)
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
@override
def verify_preview_signature(
self,
*,

View File

@ -12,7 +12,7 @@ state.
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from typing import Any, Union, override
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.helper.trace_id_helper import ParentTraceContext
@ -98,12 +98,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
@override
def on_graph_start(self) -> None:
self._workflow_execution = None
self._node_execution_cache.clear()
self._node_snapshots.clear()
self._node_sequence = 0
@override
def on_event(self, event: GraphEngineEvent) -> None:
match event:
case GraphRunStartedEvent():
@ -131,6 +133,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
case NodeRunPauseRequestedEvent():
self._handle_node_pause_requested(event)
@override
def on_graph_end(self, error: Exception | None) -> None:
return

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE
@override
def get_icon_url(self, tenant_id: str) -> str:
return self.icon

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -67,5 +68,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -6,7 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
key = str(ModelProviderID(key))
return key in self.configurations
@override
def __iter__(self):
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any, TypedDict, override
from sqlalchemy import select
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@override
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,4 +1,5 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class JavascriptCodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.JAVASCRIPT
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,10 +1,12 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from typing import Any, override
from core.helper.code_executor.template_transformer import TemplateTransformer
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
@override
def transform_response(cls, response: str):
"""
Transform response to dict
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
@override
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return script
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
import jinja2
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return runner_script
@classmethod
@override
def get_preload_script(cls) -> str:
preload_script = dedent("""
import jinja2

View File

@ -1,4 +1,5 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class Python3CodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.PYTHON3
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,10 +1,12 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from extensions.ext_redis import redis_client
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
provider_identity=provider_identity,
)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]

View File

@ -43,13 +43,16 @@ request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]:
"""Build per-scheme proxy transports with the same TLS policy as the SSRF client."""
return {
"http://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTP_URL,
verify=verify,
),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
verify=verify,
),
}
@ -64,7 +67,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client:
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
return httpx.Client(
mounts=_create_proxy_mounts(),
mounts=_create_proxy_mounts(verify=verify),
verify=verify,
limits=_SSRF_CLIENT_LIMITS,
)

View File

@ -2,6 +2,7 @@
import contextlib
import logging
from typing import override
import flask
@ -15,6 +16,7 @@ class TraceContextFilter(logging.Filter):
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
@ -54,6 +56,7 @@ class IdentityContextFilter(logging.Filter):
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, NotRequired, TypedDict
from typing import Any, NotRequired, TypedDict, override
import orjson
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
@override
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:

View File

@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from typing import Any
from typing import Any, override
from sqlalchemy.orm import Session
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
# Reset retry flag after operation completes
self._has_retried = False
@override
def __enter__(self):
"""Enter the context manager with retry support."""
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
return self._execute_with_retry(initialize_with_retry)
@override
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
"""
return self._execute_with_retry(super().list_tools)
@override
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.

View File

@ -1,6 +1,6 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
from typing import Any, Protocol, override
from pydantic import AnyUrl, TypeAdapter
@ -159,6 +159,7 @@ class ClientSession(
types.EmptyResult,
)
@override
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""Send a progress notification."""
self.send_notification(
@ -326,6 +327,7 @@ class ClientSession(
)
)
@override
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
@ -351,6 +353,7 @@ class ClientSession(
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
@override
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@ -358,6 +361,7 @@ class ClientSession(
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
@override
def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server."""
# Process specific notification types

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -25,6 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -43,6 +44,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -59,6 +61,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any
from typing import Any, override
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -8,6 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,6 +29,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -49,6 +51,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.model_manager import ModelManager
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -9,6 +9,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -19,6 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -36,6 +38,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import override
from pydantic import TypeAdapter
@ -11,6 +12,7 @@ class PluginDaemonError(Exception):
def __init__(self, description: str):
self.description = description
@override
def __str__(self) -> str:
# returns the class name and description
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"

View File

@ -3,8 +3,7 @@ from __future__ import annotations
import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Literal, cast, overload
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from redis import RedisError
@ -13,9 +12,9 @@ from configs import dify_config
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_structured_output as invoke_llm_with_structured_output_helper,
)
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.llm_entities import (
LLMResult,
@ -101,35 +100,38 @@ class _PluginStructuredOutputModelInstance:
class PluginModelRuntime(ModelRuntime):
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope."""
"""Plugin-backed runtime adapter bound to tenant context and optional caller scope.
Provider discovery goes through ``PluginService`` so the plugin lifecycle
methods and provider reads share one tenant-scoped cache owner.
"""
tenant_id: str
user_id: str | None
client: PluginModelClient
_provider_entities: tuple[ProviderEntity, ...] | None
_provider_entities_lock: Lock
_plugin_service: type[PluginService]
def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None:
def __init__(
self,
tenant_id: str,
user_id: str | None,
client: PluginModelClient,
plugin_service: type[PluginService],
) -> None:
if client is None:
raise ValueError("client is required.")
if plugin_service is None:
raise ValueError("plugin_service is required.")
self.tenant_id = tenant_id
self.user_id = user_id
self.client = client
self._provider_entities = None
self._provider_entities_lock = Lock()
self._plugin_service = plugin_service
@override
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
with self._provider_entities_lock:
if self._provider_entities is None:
self._provider_entities = tuple(
self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id)
)
return self._provider_entities
return self._plugin_service.fetch_plugin_model_providers(tenant_id=self.tenant_id, client=self.client)
@override
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime):
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@override
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def validate_model_credentials(
self,
*,
@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def get_model_schema(
self,
*,
@ -294,6 +299,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
@override
def invoke_llm(
self,
*,
@ -357,6 +363,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@override
def invoke_llm_with_structured_output(
self,
*,
@ -396,6 +403,7 @@ class PluginModelRuntime(ModelRuntime):
stream=stream,
)
@override
def get_llm_num_tokens(
self,
*,
@ -422,6 +430,7 @@ class PluginModelRuntime(ModelRuntime):
tools=list(tools) if tools else None,
)
@override
def invoke_text_embedding(
self,
*,
@ -443,6 +452,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def invoke_multimodal_embedding(
self,
*,
@ -464,6 +474,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def get_text_embedding_num_tokens(
self,
*,
@ -483,6 +494,7 @@ class PluginModelRuntime(ModelRuntime):
texts=texts,
)
@override
def invoke_rerank(
self,
*,
@ -508,6 +520,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_multimodal_rerank(
self,
*,
@ -533,6 +546,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_tts(
self,
*,
@ -554,6 +568,7 @@ class PluginModelRuntime(ModelRuntime):
voice=voice,
)
@override
def get_tts_model_voices(
self,
*,
@ -573,6 +588,7 @@ class PluginModelRuntime(ModelRuntime):
language=language,
)
@override
def invoke_speech_to_text(
self,
*,
@ -592,6 +608,7 @@ class PluginModelRuntime(ModelRuntime):
file=file,
)
@override
def invoke_moderation(
self,
*,
@ -611,34 +628,6 @@ class PluginModelRuntime(ModelRuntime):
text=text,
)
def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = self._get_provider_short_name_alias(provider)
return declaration
def _get_provider_schema(self, provider: str) -> ProviderEntity:
providers = self.fetch_model_providers()
provider_entity = next((item for item in providers if item.provider == provider), None)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.plugin.impl.model import PluginModelClient
from core.plugin.plugin_service import PluginService
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.ai_model import AIModel
@ -117,6 +118,7 @@ def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -
tenant_id=tenant_id,
user_id=user_id,
client=PluginModelClient(),
plugin_service=PluginService,
)

View File

@ -1,8 +1,17 @@
"""Core plugin service and tenant-scoped plugin metadata cache ownership.
This module owns plugin daemon management calls that are shared by API services
and core runtimes. Plugin model provider discovery is cached here, alongside
plugin install, uninstall, and upgrade invalidation, so all cache mutations for
plugin-owned provider metadata stay tenant-scoped and in one place.
"""
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from pydantic import BaseModel
from pydantic import BaseModel, TypeAdapter, ValidationError
from redis import RedisError
from sqlalchemy import delete, select, update
from sqlalchemy.orm import Session
from yarl import URL
@ -22,16 +31,20 @@ from core.plugin.entities.plugin import (
from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse,
PluginInstallTask,
PluginInstallTaskStatus,
PluginListResponse,
PluginModelProviderEntity,
PluginVerification,
)
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.model import PluginModelClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from models.provider_ids import GenericProviderID
from models.provider_ids import GenericProviderID, ModelProviderID
from services.enterprise.plugin_manager_service import (
PluginManagerService,
PreUninstallPluginRequest,
@ -40,6 +53,7 @@ from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
logger = logging.getLogger(__name__)
_provider_entities_adapter: TypeAdapter[list[ProviderEntity]] = TypeAdapter(list[ProviderEntity])
class PluginService:
@ -53,6 +67,102 @@ class PluginService:
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes
PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX = "plugin_model_providers:tenant_id:"
PLUGIN_INSTALL_TASK_TERMINAL_STATUSES = (PluginInstallTaskStatus.Success, PluginInstallTaskStatus.Failed)
@classmethod
def _get_plugin_model_providers_cache_key(cls, tenant_id: str) -> str:
return f"{cls.PLUGIN_MODEL_PROVIDERS_REDIS_KEY_PREFIX}{tenant_id}"
@staticmethod
def _get_provider_short_name_alias(provider: PluginModelProviderEntity) -> str:
"""
Expose a bare provider alias only for the canonical provider mapping.
Multiple plugins can publish the same short provider slug. If every
provider entity keeps that slug in ``provider_name``, callers that still
resolve by short name become order-dependent. Restrict the alias to the
provider selected by ``ModelProviderID`` so legacy short-name lookups
remain deterministic while the runtime surface stays canonical.
"""
try:
canonical_provider_id = ModelProviderID(provider.provider)
except ValueError:
return ""
if canonical_provider_id.plugin_id != provider.plugin_id:
return ""
if canonical_provider_id.provider_name != provider.provider:
return ""
return provider.provider
@classmethod
def _to_provider_entity(cls, provider: PluginModelProviderEntity) -> ProviderEntity:
declaration = provider.declaration.model_copy(deep=True)
declaration.provider = f"{provider.plugin_id}/{provider.provider}"
declaration.provider_name = cls._get_provider_short_name_alias(provider)
return declaration
@classmethod
def _load_cached_plugin_model_providers(cls, tenant_id: str) -> tuple[ProviderEntity, ...] | None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
try:
cached_providers = redis_client.get(cache_key)
except (RedisError, RuntimeError):
logger.warning("Failed to read cached plugin model providers for tenant %s.", tenant_id, exc_info=True)
return None
if not cached_providers:
return None
try:
return tuple(_provider_entities_adapter.validate_json(cached_providers))
except (TypeError, ValueError, ValidationError):
logger.warning(
"Invalid cached plugin model providers for tenant %s; deleting cache.", tenant_id, exc_info=True
)
cls.invalidate_plugin_model_providers_cache(tenant_id)
return None
@classmethod
def _store_cached_plugin_model_providers(cls, tenant_id: str, providers: Sequence[ProviderEntity]) -> None:
cache_key = cls._get_plugin_model_providers_cache_key(tenant_id)
try:
payload = _provider_entities_adapter.dump_json(list(providers)).decode("utf-8")
redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_PROVIDERS_CACHE_TTL, payload)
except (RedisError, RuntimeError):
logger.warning("Failed to cache plugin model providers for tenant %s.", tenant_id, exc_info=True)
@classmethod
def invalidate_plugin_model_providers_cache(cls, tenant_id: str) -> None:
"""Delete the tenant-scoped plugin model provider list cache."""
try:
redis_client.delete(cls._get_plugin_model_providers_cache_key(tenant_id))
except (RedisError, RuntimeError):
logger.warning("Failed to invalidate plugin model providers cache for tenant %s.", tenant_id, exc_info=True)
@classmethod
def fetch_plugin_model_providers(
cls, *, tenant_id: str, client: PluginModelClient | None = None
) -> Sequence[ProviderEntity]:
"""
Fetch plugin model providers through the tenant-scoped plugin cache.
Plugin daemon provider discovery and plugin lifecycle cache invalidation
are intentionally owned by this service so tenant isolation and cache
expiry are handled in one place.
"""
cached_providers = cls._load_cached_plugin_model_providers(tenant_id)
if cached_providers is not None:
return cached_providers
model_client = client or PluginModelClient()
providers = tuple(
cls._to_provider_entity(provider) for provider in model_client.fetch_model_providers(tenant_id)
)
cls._store_cached_plugin_model_providers(tenant_id, providers)
return providers
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
@ -248,12 +358,18 @@ class PluginService:
Fetch plugin installation tasks
"""
manager = PluginInstaller()
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
tasks = manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
if any(task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES for task in tasks):
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return tasks
@staticmethod
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
manager = PluginInstaller()
return manager.fetch_plugin_installation_task(tenant_id, task_id)
task = manager.fetch_plugin_installation_task(tenant_id, task_id)
if task.status in PluginService.PLUGIN_INSTALL_TASK_TERMINAL_STATUSES:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return task
@staticmethod
def delete_install_task(tenant_id: str, task_id: str) -> bool:
@ -315,7 +431,7 @@ class PluginService:
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
return manager.upgrade_plugin(
result = manager.upgrade_plugin(
tenant_id,
original_plugin_unique_identifier,
new_plugin_unique_identifier,
@ -324,6 +440,8 @@ class PluginService:
"plugin_unique_identifier": new_plugin_unique_identifier,
},
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def upgrade_plugin_with_github(
@ -339,7 +457,7 @@ class PluginService:
"""
PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
return manager.upgrade_plugin(
result = manager.upgrade_plugin(
tenant_id,
original_plugin_unique_identifier,
new_plugin_unique_identifier,
@ -350,6 +468,8 @@ class PluginService:
"package": package,
},
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
@ -415,12 +535,14 @@ class PluginService:
resp = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(resp.verification)
return manager.install_from_identifiers(
result = manager.install_from_identifiers(
tenant_id,
plugin_unique_identifiers,
PluginInstallationSource.Package,
[{}],
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def install_from_github(tenant_id: str, plugin_unique_identifier: str, repo: str, version: str, package: str):
@ -434,7 +556,7 @@ class PluginService:
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
return manager.install_from_identifiers(
result = manager.install_from_identifiers(
tenant_id,
[plugin_unique_identifier],
PluginInstallationSource.Github,
@ -446,6 +568,8 @@ class PluginService:
}
],
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
@ -513,12 +637,14 @@ class PluginService:
actual_plugin_unique_identifiers.append(response.unique_identifier)
metas.append({"plugin_unique_identifier": response.unique_identifier})
return manager.install_from_identifiers(
result = manager.install_from_identifiers(
tenant_id,
actual_plugin_unique_identifiers,
PluginInstallationSource.Marketplace,
metas,
)
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
@ -529,7 +655,10 @@ class PluginService:
plugin = next((p for p in plugins if p.installation_id == plugin_installation_id), None)
if not plugin:
return manager.uninstall(tenant_id, plugin_installation_id)
result = manager.uninstall(tenant_id, plugin_installation_id)
if result:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
if dify_config.ENTERPRISE_ENABLED:
PluginManagerService.try_pre_uninstall_plugin(
@ -559,37 +688,39 @@ class PluginService:
if not credential_ids:
logger.info("No credentials found for plugin: %s", plugin_id)
return manager.uninstall(tenant_id, plugin_installation_id)
else:
provider_ids = session.scalars(
select(Provider.id).where(
Provider.tenant_id == tenant_id,
Provider.provider_name.like(f"{plugin_id}/%"),
Provider.credential_id.in_(credential_ids),
)
).all()
provider_ids = session.scalars(
select(Provider.id).where(
Provider.tenant_id == tenant_id,
Provider.provider_name.like(f"{plugin_id}/%"),
Provider.credential_id.in_(credential_ids),
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
for provider_id in provider_ids:
ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
).delete()
session.execute(
delete(ProviderCredential).where(
ProviderCredential.id.in_(credential_ids),
)
)
).all()
session.execute(update(Provider).where(Provider.id.in_(provider_ids)).values(credential_id=None))
for provider_id in provider_ids:
ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=provider_id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
).delete()
session.execute(
delete(ProviderCredential).where(
ProviderCredential.id.in_(credential_ids),
logger.info(
"Completed deleting credentials and cleaning provider associations for plugin: %s",
plugin_id,
)
)
logger.info(
"Completed deleting credentials and cleaning provider associations for plugin: %s",
plugin_id,
)
return manager.uninstall(tenant_id, plugin_installation_id)
result = manager.uninstall(tenant_id, plugin_installation_id)
if result:
PluginService.invalidate_plugin_model_providers_cache(tenant_id)
return result
@staticmethod
def check_tools_existence(tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, TypedDict
from typing import Any, TypedDict, override
import orjson
from pydantic import BaseModel
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
super().__init__(dataset)
self._config = KeywordTableConfig()
@override
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
return self
@override
def add_texts(self, texts: list[Document], **kwargs):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
return False
return id in set.union(*keyword_table.values())
@override
def delete_by_ids(self, ids: list[str]):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
return documents
@override
def delete(self):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):

View File

@ -2,7 +2,7 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, override
from sqlalchemy import select
@ -72,21 +72,27 @@ class _LazyEmbeddings(Embeddings):
self._real = CacheEmbedding(embedding_model)
return self._real
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
@override
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
@override
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
@override
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)
@override
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)
@override
async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)

View File

@ -16,6 +16,7 @@ from core.plugin.entities.request import (
TriggerSubscriptionResponse,
)
from core.plugin.impl.trigger import PluginTriggerClient
from core.plugin.plugin_service import PluginService
from core.trigger.entities.api_entities import EventApiEntity, TriggerProviderApiEntity
from core.trigger.entities.entities import (
EventEntity,
@ -30,7 +31,6 @@ from core.trigger.entities.entities import (
)
from core.trigger.errors import TriggerProviderCredentialValidationError
from models.provider_ids import TriggerProviderID
from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)

View File

@ -37,6 +37,10 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from core.workflow.nodes.agent_v2 import DifyAgentNode
from core.workflow.nodes.agent_v2.binding_resolver import WorkflowAgentBindingResolver
from core.workflow.nodes.agent_v2.output_adapter import WorkflowAgentOutputAdapter
from core.workflow.nodes.agent_v2.runtime_request_builder import WorkflowAgentRuntimeRequestBuilder
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
from graphon.entities.base_node_data import BaseNodeData
@ -438,12 +442,7 @@ class DifyNodeFactory(NodeFactory):
"tool_file_manager": self._bound_tool_file_manager_factory(),
"runtime": self._tool_runtime,
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
BuiltinNodeTypes.AGENT: lambda: self._build_agent_node_init_kwargs(node_class=node_class),
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
@ -469,6 +468,32 @@ class DifyNodeFactory(NodeFactory):
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
def _build_agent_node_init_kwargs(self, *, node_class: type[Node]) -> dict[str, object]:
if issubclass(node_class, DifyAgentNode):
from clients.agent_backend import AgentBackendRunEventAdapter, AgentBackendRunRequestBuilder
from clients.agent_backend.factory import create_agent_backend_run_client
return {
"binding_resolver": WorkflowAgentBindingResolver(),
"runtime_request_builder": WorkflowAgentRuntimeRequestBuilder(
credentials_provider=self._llm_credentials_provider,
request_builder=AgentBackendRunRequestBuilder(),
),
"agent_backend_client": create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
),
"event_adapter": AgentBackendRunEventAdapter(),
"output_adapter": WorkflowAgentOutputAdapter(),
}
return {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
}
def _build_llm_compatible_node_init_kwargs(
self,
*,

View File

@ -0,0 +1,4 @@
from .agent_node import DifyAgentNode
from .entities import DifyAgentNodeData
__all__ = ["DifyAgentNode", "DifyAgentNodeData"]

Some files were not shown because too many files have changed in this diff Show More