Compare commits

..

41 Commits

Author SHA1 Message Date
yyh
40d06bd476 refactor(dify-ui): tighten textarea count state 2026-05-23 18:33:04 +08:00
yyh
13e024618a revert(dify-ui): drop switch story change 2026-05-23 18:27:17 +08:00
yyh
5adf995527 lint 2026-05-23 17:43:17 +08:00
yyh
5d4def8298 fix(web): label migrated textarea controls 2026-05-23 17:42:11 +08:00
yyh
d5d0d2d96f feat(dify-ui): add textarea primitive 2026-05-23 17:30:07 +08:00
2a0c098857 refactor: convert isinstance chains to match/case in otel parser (#36534)
Co-authored-by: Cowork 3P <cowork-3p@localhost>
2026-05-22 18:39:24 +00:00
790ca72627 refactor(api): migrate console/service_api.dataset to BaseModel (#36480) 2026-05-22 17:39:07 +00:00
4d8b6c7dc0 refactor: add missing @override decorator to remaining MCP, Jieba, embeddings, and misc subclasses (#36528) 2026-05-22 13:45:35 +00:00
473c945839 chore: seprate vector space quota query (#36514)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-22 09:26:17 +00:00
yyh
a698c60b29 fix(web): stabilize document picker search focus (#36525) 2026-05-22 09:24:37 +00:00
yyh
24bab5fb2a refactor(web): improve retrieval and tag control semantics (#36521) 2026-05-22 09:01:31 +00:00
yyh
93b7a81071 fix(dify-ui): align form label guidance (#36510) 2026-05-22 07:29:57 +00:00
157e6244dd refactor: add missing @override decorator to agent runners, tool caches, and logging extensions (#36511) 2026-05-22 06:41:48 +00:00
yyh
964aaad7ed refactor: streamline workflow context menu lifecycle (#36500) 2026-05-22 04:31:39 +00:00
92181dbe09 fix(api): preserve remote file URL query params (#36478) 2026-05-22 01:45:20 +00:00
30deef45d9 fix(api): pass SSL verify flag to SSRF proxy mounts (#36455) 2026-05-22 01:31:46 +00:00
ee28074390 refactor: add missing @override decorator to Moderation subclasses (#36492) 2026-05-21 19:42:20 +00:00
1fb491337b refactor: add missing @override decorator to datasource plugin classes (#36494) 2026-05-21 19:41:42 +00:00
82b0a03f5a refactor: add missing @override decorator to PluginModelRuntime (#36493) 2026-05-21 19:40:40 +00:00
6185016910 refactor: add missing @override decorator to file access controller and workflow file runtime (#36495) 2026-05-21 19:39:51 +00:00
b4f5f4869f refactor: add missing @override decorator to code executor providers and transformers (#36496) 2026-05-21 19:39:10 +00:00
7ecbed3b04 chore: add Type to test (#36454)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-21 19:25:03 +00:00
5b58defd62 refactor: add missing @override decorator to GraphEngineLayer subclasses (#36491) 2026-05-21 16:32:02 +00:00
73196de5e1 refactor: add missing @override decorator to AppQueueManager subclasses (#36490) 2026-05-21 16:25:07 +00:00
ea5e487d3c fix(api): stop returning 204 with response body and add CI check (#36489) 2026-05-21 16:20:34 +00:00
f19702f76c feat(api): Flask-RESTX response() vs actual return value checker (#36488) 2026-05-21 15:05:06 +00:00
092c8bca81 refactor(api): migrate console.datasets.metadata to BaseModel (#36450)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-21 15:04:42 +00:00
c50d504c44 refactor: add missing @override decorator to AppGenerateResponseConverter subclasses (#36486) 2026-05-21 14:00:22 +00:00
1b4356b66a fix(ci): bad pyinfra type coverage report comments (#36482) 2026-05-21 12:08:24 +00:00
yyh
7f633622aa fix(web): use popup-open selectors for trigger styles (#36471) 2026-05-21 06:13:11 +00:00
yyh
66f5ab4cfc feat: add dify-ui input primitive (#36446) 2026-05-21 03:15:38 +00:00
0cf9597f52 fix: suggested questions API crash on legacy conversation override configs (#36459) 2026-05-21 01:58:52 +00:00
60cd346fa6 feat: wire workflow agent node runtime (#36437)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 12:39:45 +00:00
56d4d54c16 chore: compatiable conversation is not exists (#33274)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-05-20 12:37:33 +00:00
yyh
9f9cb4d17e feat(ui): migrate radio to Base UI and update web callsites (#36451)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 12:05:31 +00:00
yyh
7d0d9019d8 chore: upgrade base ui to 1.5.0 (#36442) 2026-05-20 09:58:08 +00:00
d646bcf257 chore: remove unused pyrefly ignore comments in dataset.py (#36443)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-05-20 09:14:43 +00:00
e3b45a48eb fix: allow config pubsub join timeout for lower post-run latency (#36438)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
2026-05-20 08:45:51 +00:00
848c15a265 chore: update to only SaaS can view template (#36440) 2026-05-20 08:18:26 +00:00
yyh
be8627233d ci: show web test shard failures (#36436) 2026-05-20 08:03:15 +00:00
1fe8b7fb1d fix(auth): use validity-returned token in ChangePasswordForm reset submit (#36415) 2026-05-20 07:59:09 +00:00
446 changed files with 13723 additions and 6757 deletions

View File

@ -63,8 +63,8 @@ jobs:
id: render
run: |
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
--base base_report.json \
< pr_report.json)"
--base "$GITHUB_WORKSPACE/base_report.json" \
< "$GITHUB_WORKSPACE/pr_report.json")"
{
echo "### Pyrefly Type Coverage"

View File

@ -65,6 +65,9 @@ jobs:
# Save structured data for the fork-PR comment workflow
cp /tmp/pyrefly_report_pr.json pr_report.json
cp /tmp/pyrefly_report_base.json base_report.json
# Keep fork-PR comments correct while the trusted workflow_run job is
# still using the default-branch renderer, which resolves --base from api/.
cp /tmp/pyrefly_report_base.json api/base_report.json
- name: Save PR number
run: |
@ -77,6 +80,7 @@ jobs:
path: |
pr_report.json
base_report.json
api/base_report.json
pr_number.txt
- name: Comment PR with type coverage

View File

@ -47,6 +47,10 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Response Contract Linter
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
- name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: make type-check-core

View File

@ -39,7 +39,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Run tests
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
run: vp test run --reporter=blob --reporter=minimal --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}

View File

@ -75,13 +75,19 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
@uv run --project api --dev ruff format ./api
@uv run --project api --dev ruff check --fix ./api
@$(MAKE) api-contract-lint
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
api-contract-lint:
@echo "🔎 Linting Flask response contracts..."
@uv run --project api --dev python api/dev/lint_response_contracts.py
@echo "✅ Response contract lint complete"
type-check:
@echo "📝 Running type checks (pyrefly + mypy)..."
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
@ -191,6 +197,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
@echo " make type-check - Run type checks (pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@ -204,4 +211,4 @@ help:
@echo " make build-push-all - Build and push all Docker images"
# Phony targets
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all

View File

@ -767,6 +767,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while use redis as event bus.
# It's highly recommended to enable this for large deployments.
EVENT_BUS_REDIS_USE_CLUSTERS=false
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true

View File

@ -195,6 +195,7 @@ Before opening a PR / submitting:
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise docstrings and comments.
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,

View File

@ -49,6 +49,7 @@ class AgentBackendModelConfig(BaseModel):
model: str
user_id: str | None = None
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@ -138,6 +139,7 @@ class AgentBackendRunRequestBuilder:
model_provider=run_input.model.model_provider,
model=run_input.model.model,
credentials=run_input.model.credentials,
model_settings=run_input.model.model_settings or None,
),
),
]

View File

@ -1,3 +1,4 @@
from configs.extra.agent_backend_config import AgentBackendConfig
from configs.extra.archive_config import ArchiveStorageConfig
from configs.extra.notion_config import NotionConfig
from configs.extra.sentry_config import SentryConfig
@ -5,6 +6,7 @@ from configs.extra.sentry_config import SentryConfig
class ExtraServiceConfig(
# place the configs in alphabet order
AgentBackendConfig,
ArchiveStorageConfig,
NotionConfig,
SentryConfig,

View File

@ -0,0 +1,23 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class AgentBackendConfig(BaseSettings):
"""
Configuration settings for the Agent backend runtime integration.
"""
AGENT_BACKEND_BASE_URL: str | None = Field(
description="Base URL for the Dify Agent backend service.",
default=None,
)
AGENT_BACKEND_USE_FAKE: bool = Field(
description="Use the deterministic in-process fake Agent backend client.",
default=False,
)
AGENT_BACKEND_FAKE_SCENARIO: str = Field(
description="Scenario used by the fake Agent backend client.",
default="success",
)

View File

@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
from pydantic.types import NonNegativeInt
from pydantic_settings import BaseSettings
@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
default=600,
)
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
description=(
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
"an SSE client and the response stream actually closing.\n\n"
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
"return promptly while the daemon listener thread cleans itself up on the next poll "
"boundary - safe because the listener holds no critical state and exits within one poll "
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
),
default=2000,
)
def _build_default_pubsub_url(self) -> str:
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:

View File

@ -36,6 +36,24 @@ class FileInfo(BaseModel):
size: int
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
decoded_url = urllib.parse.unquote(url)
if isinstance(query_string, bytes):
raw_query = query_string.decode()
else:
raw_query = query_string
if not raw_query:
return decoded_url
if decoded_url.endswith(("?", "&")):
separator = ""
elif urllib.parse.urlsplit(decoded_url).query:
separator = "&"
else:
separator = "?"
return f"{decoded_url}{separator}{raw_query}"
def guess_file_info_from_response(response: httpx.Response):
url = str(response.url)
# Try to extract filename from URL

View File

@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource):
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
db.session.commit()
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:resource_id>/api-keys")

View File

@ -269,12 +269,12 @@ class AnnotationApi(Resource):
"message": "annotation_ids are required if the parameter is provided.",
}, 400
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return result, 204
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
return "", 204
# If no annotation_ids are provided, handle clearing all annotations
else:
AppAnnotationService.clear_all_annotations(str(app_id))
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def delete(self, app_id: UUID, annotation_id: UUID):
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")

View File

@ -633,7 +633,7 @@ class AppApi(Resource):
app_service = AppService()
app_service.delete_app(app_model)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/copy")

View File

@ -29,9 +29,6 @@ from fields.conversation_fields import (
from fields.conversation_fields import (
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
)
from fields.conversation_fields import (
ResultResponse,
)
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.login import current_account_with_tenant, login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@ -77,7 +74,6 @@ register_schema_models(
ConversationMessageDetailResponse,
ConversationWithSummaryPaginationResponse,
ConversationDetailResponse,
ResultResponse,
CompletionConversationQuery,
ChatConversationQuery,
)
@ -194,7 +190,7 @@ class CompletionConversationDetailApi(Resource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
@ -347,7 +343,7 @@ class ChatConversationDetailApi(Resource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
def _get_conversation(app_model, conversation_id):

View File

@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204
return "", 204
except Exception as e:
raise BadRequest(str(e))

View File

@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
user_id=current_user.id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
user_id=current_user.id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")

View File

@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
return {"result": "success"}, 204
return "", 204

View File

@ -1,15 +1,16 @@
from typing import Any, cast
from datetime import datetime
from typing import Any
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.app.error import ProviderNotInitializeError
@ -30,26 +31,10 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import app_detail_kernel_fields, related_app_list
from fields.dataset_fields import (
content_fields,
dataset_detail_fields,
dataset_fields,
dataset_query_detail_fields,
dataset_retrieval_model_fields,
doc_metadata_fields,
external_knowledge_info_fields,
external_retrieval_model_fields,
file_info_fields,
icon_info_fields,
keyword_setting_fields,
reranking_model_fields,
tag_fields,
vector_setting_fields,
weighted_score_fields,
)
from fields.document_fields import document_status_fields
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import build_icon_url, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from libs.url_utils import normalize_api_base_url
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
@ -61,58 +46,6 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
tag_model = get_or_create_model("Tag", tag_fields)
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
weighted_score_fields_copy = weighted_score_fields.copy()
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
dataset_detail_fields_copy = dataset_detail_fields.copy()
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
content_fields_copy = content_fields.copy()
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
content_model = get_or_create_model("DatasetContent", content_fields_copy)
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
related_app_list_copy = related_app_list.copy()
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
def _validate_indexing_technique(value: str | None) -> str | None:
if value is None:
@ -208,9 +141,165 @@ class ConsoleDatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetListItemResponse(DatasetDetailResponse):
partial_member_list: list[str]
class DatasetListResponse(ResponseModel):
data: list[DatasetListItemResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
class DatasetQueryFileInfoResponse(ResponseModel):
id: str
name: str
size: int
extension: str
mime_type: str
source_url: str
class DatasetQueryContentResponse(ResponseModel):
content_type: str
content: str
file_info: DatasetQueryFileInfoResponse | None = None
class DatasetQueryDetailResponse(ResponseModel):
id: str
queries: list[DatasetQueryContentResponse]
source: str
source_app_id: str | None
created_by_role: str
created_by: str
created_at: int
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DatasetQueryListResponse(ResponseModel):
data: list[DatasetQueryDetailResponse]
has_more: bool
limit: int
total: int
page: int
class RelatedAppResponse(ResponseModel):
id: str
name: str
description: str
mode: str = Field(validation_alias="mode_compatible_with_agent")
icon_type: str | None
icon: str | None
icon_background: str | None
icon_url: str | None = None
@model_validator(mode="after")
def _set_icon_url(self) -> "RelatedAppResponse":
self.icon_url = self.icon_url or build_icon_url(self.icon_type, self.icon)
return self
class RelatedAppListResponse(ResponseModel):
data: list[RelatedAppResponse]
total: int
class DocumentStatusResponse(ResponseModel):
id: str
indexing_status: str
processing_started_at: int | None
parsing_completed_at: int | None
cleaning_completed_at: int | None
splitting_completed_at: int | None
completed_at: int | None
paused_at: int | None
error: str | None
stopped_at: int | None
completed_segments: int | None = None
total_segments: int | None = None
@field_validator(
"processing_started_at",
"parsing_completed_at",
"cleaning_completed_at",
"splitting_completed_at",
"completed_at",
"paused_at",
"stopped_at",
mode="before",
)
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return to_timestamp(value)
class DocumentStatusListResponse(ResponseModel):
data: list[DocumentStatusResponse]
class ErrorDocsResponse(DocumentStatusListResponse):
total: int
class IndexingEstimatePreviewItemResponse(ResponseModel):
content: str
child_chunks: list[str] | None = None
summary: str | None = None
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
question: str
answer: str
class IndexingEstimateResponse(ResponseModel):
total_segments: int
preview: list[IndexingEstimatePreviewItemResponse]
qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
class RetrievalSettingResponse(ResponseModel):
retrieval_method: list[str]
class PartialMemberListResponse(ResponseModel):
data: list[str]
class AutoDisableLogsResponse(ResponseModel):
document_ids: list[str]
count: int
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
register_response_schema_models(
console_ns,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetQueryListResponse,
IndexingEstimateResponse,
RelatedAppListResponse,
DocumentStatusListResponse,
ErrorDocsResponse,
RetrievalSettingResponse,
PartialMemberListResponse,
AutoDisableLogsResponse,
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -293,17 +382,8 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
class DatasetListApi(Resource):
@console_ns.doc("get_datasets")
@console_ns.doc(description="Get list of datasets")
@console_ns.doc(
params={
"page": "Page number (default: 1)",
"limit": "Number of items per page (default: 20)",
"ids": "Filter by dataset IDs (list)",
"keyword": "Search keyword",
"tag_ids": "Filter by tag IDs (list)",
"include_all": "Include all datasets (default: false)",
}
)
@console_ns.response(200, "Datasets retrieved successfully")
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
@setup_required
@login_required
@account_initialization_required
@ -342,7 +422,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
partial_members_map: dict[str, list[str]] = {}
if dataset_ids:
@ -379,12 +459,12 @@ class DatasetListApi(Resource):
"total": total,
"page": query.page,
}
return response, 200
return dump_response(DatasetListResponse, response), 200
@console_ns.doc("create_dataset")
@console_ns.doc(description="Create a new dataset")
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
@console_ns.response(201, "Dataset created successfully")
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
@console_ns.response(400, "Invalid request parameters")
@setup_required
@login_required
@ -413,7 +493,7 @@ class DatasetListApi(Resource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 201
return dump_response(DatasetDetailResponse, dataset), 201
@console_ns.route("/datasets/<uuid:dataset_id>")
@ -421,7 +501,11 @@ class DatasetApi(Resource):
@console_ns.doc("get_dataset")
@console_ns.doc(description="Get dataset details")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
@console_ns.response(
200,
"Dataset retrieved successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -437,7 +521,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
data = dump_response(DatasetDetailResponse, dataset)
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@ -470,7 +554,11 @@ class DatasetApi(Resource):
@console_ns.doc("update_dataset")
@console_ns.doc(description="Update dataset details")
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
@console_ns.response(
200,
"Dataset updated successfully",
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -506,7 +594,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
result_data = dump_response(DatasetDetailResponse, dataset)
tenant_id = current_tenant_id
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
@ -535,7 +623,7 @@ class DatasetApi(Resource):
try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204
return "", 204
else:
raise NotFound("Dataset not found.")
except services.errors.dataset.DatasetInUseError:
@ -567,7 +655,11 @@ class DatasetQueryApi(Resource):
@console_ns.doc("get_dataset_queries")
@console_ns.doc(description="Get dataset query history")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
@console_ns.response(
200,
"Query history retrieved successfully",
console_ns.models[DatasetQueryListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -589,20 +681,24 @@ class DatasetQueryApi(Resource):
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
response = {
"data": marshal(dataset_queries, dataset_query_detail_model),
"data": dataset_queries,
"has_more": len(dataset_queries) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
return dump_response(DatasetQueryListResponse, response), 200
@console_ns.route("/datasets/indexing-estimate")
class DatasetIndexingEstimateApi(Resource):
@console_ns.doc("estimate_dataset_indexing")
@console_ns.doc(description="Estimate dataset indexing cost")
@console_ns.response(200, "Indexing estimate calculated successfully")
@console_ns.response(
200,
"Indexing estimate calculated successfully",
console_ns.models[IndexingEstimateResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -699,11 +795,14 @@ class DatasetRelatedAppListApi(Resource):
@console_ns.doc("get_dataset_related_apps")
@console_ns.doc(description="Get applications related to dataset")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
@console_ns.response(
200,
"Related apps retrieved successfully",
console_ns.models[RelatedAppListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@marshal_with(related_app_list_model)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -724,7 +823,7 @@ class DatasetRelatedAppListApi(Resource):
if app_model:
related_apps.append(app_model)
return {"data": related_apps, "total": len(related_apps)}, 200
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
@ -732,7 +831,11 @@ class DatasetIndexingStatusApi(Resource):
@console_ns.doc("get_dataset_indexing_status")
@console_ns.doc(description="Get dataset indexing status")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Indexing status retrieved successfully")
@console_ns.response(
200,
"Indexing status retrieved successfully",
console_ns.models[DocumentStatusListResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
@ -778,9 +881,8 @@ class DatasetIndexingStatusApi(Resource):
"completed_segments": completed_segments,
"total_segments": total_segments,
}
documents_status.append(marshal(document_dict, document_status_fields))
data = {"data": documents_status}
return data, 200
documents_status.append(document_dict)
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
@console_ns.route("/datasets/api-keys")
@ -873,7 +975,7 @@ class DatasetApiDeleteApi(Resource):
db.session.delete(key)
db.session.commit()
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
@ -907,13 +1009,18 @@ class DatasetApiBaseUrlApi(Resource):
class DatasetRetrievalSettingApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting")
@console_ns.doc(description="Get dataset retrieval settings")
@console_ns.response(200, "Retrieval settings retrieved successfully")
@console_ns.response(
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
)
@setup_required
@login_required
@account_initialization_required
def get(self):
vector_type = dify_config.VECTOR_STORE
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
)
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
@ -921,12 +1028,19 @@ class DatasetRetrievalSettingMockApi(Resource):
@console_ns.doc("get_dataset_retrieval_setting_mock")
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
@console_ns.doc(params={"vector_type": "Vector store type"})
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
@console_ns.response(
200,
"Mock retrieval settings retrieved successfully",
console_ns.models[RetrievalSettingResponse.__name__],
)
@setup_required
@login_required
@account_initialization_required
def get(self, vector_type):
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
return dump_response(
RetrievalSettingResponse,
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
)
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
@ -934,7 +1048,7 @@ class DatasetErrorDocs(Resource):
@console_ns.doc("get_dataset_error_docs")
@console_ns.doc(description="Get dataset error documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Error documents retrieved successfully")
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@ -946,7 +1060,7 @@ class DatasetErrorDocs(Resource):
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
@ -954,7 +1068,11 @@ class DatasetPermissionUserListApi(Resource):
@console_ns.doc("get_dataset_permission_users")
@console_ns.doc(description="Get dataset permission user list")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Permission users retrieved successfully")
@console_ns.response(
200,
"Permission users retrieved successfully",
console_ns.models[PartialMemberListResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@console_ns.response(403, "Permission denied")
@setup_required
@ -973,9 +1091,7 @@ class DatasetPermissionUserListApi(Resource):
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
return {
"data": partial_members_list,
}, 200
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
@ -983,7 +1099,11 @@ class DatasetAutoDisableLogApi(Resource):
@console_ns.doc("get_dataset_auto_disable_logs")
@console_ns.doc(description="Get dataset auto disable logs")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.response(200, "Auto disable logs retrieved successfully")
@console_ns.response(
200,
"Auto disable logs retrieved successfully",
console_ns.models[AutoDisableLogsResponse.__name__],
)
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@ -993,4 +1113,4 @@ class DatasetAutoDisableLogApi(Resource):
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200

View File

@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/init")
@ -966,7 +966,7 @@ class DocumentApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot pause completed document.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Document is not in paused status.")
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource):
# retry document
DocumentService.retry_document(dataset_id, retry_documents)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")

View File

@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
return {"result": "success"}, 204
return "", 204
@console_ns.route(
@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 204
return "", 204
@setup_required
@login_required

View File

@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")

View File

@ -1,14 +1,18 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from fields.dataset_fields import dataset_metadata_fields
from fields.dataset_fields import (
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from libs.login import current_account_with_tenant, login_required
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
@ -22,7 +26,12 @@ from services.metadata_service import MetadataService
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)
register_response_schema_models(console_ns, SimpleResultResponse)
register_response_schema_models(
console_ns,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -31,7 +40,7 @@ class DatasetMetadataCreateApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
@ -44,18 +53,22 @@ class DatasetMetadataCreateApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return metadata, 201
return dump_response(DatasetMetadataResponse, metadata), 201
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -64,7 +77,7 @@ class DatasetMetadataApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@marshal_with(dataset_metadata_fields)
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
def patch(self, dataset_id, metadata_id):
current_user, _ = current_account_with_tenant()
@ -79,7 +92,7 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
return metadata, 200
return dump_response(DatasetMetadataResponse, metadata), 200
@setup_required
@login_required
@ -96,7 +109,8 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
return {"result": "success"}, 204
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
@console_ns.route("/datasets/metadata/built-in")
@ -105,9 +119,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(
200,
"Built-in fields retrieved successfully",
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self):
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -116,7 +135,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
@login_required
@account_initialization_required
@enterprise_license_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(204, "Action completed successfully")
def post(self, dataset_id, action: Literal["enable", "disable"]):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -130,7 +149,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
return "", 204
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -140,7 +160,10 @@ class DocumentMetadataEditApi(Resource):
@account_initialization_required
@enterprise_license_required
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@console_ns.response(
204,
"Documents metadata updated successfully",
)
def post(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
@ -153,4 +176,5 @@ class DocumentMetadataEditApi(Resource):
MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200
# Frontend callers only await success and invalidate caches; no response body is consumed.
return "", 204

View File

@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@console_ns.route(

View File

@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
db.session.delete(installed_app)
db.session.commit()
return {"result": "success", "message": "App uninstalled successfully"}, 204
return "", 204
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
def patch(self, installed_app):

View File

@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204

View File

@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
return {"result": "success"}, 204
return "", 204

View File

@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
from controllers.common.schema import register_response_schema_models
from libs.login import current_account_with_tenant, current_user, login_required
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
from . import console_ns
from .wraps import account_initialization_required, cloud_utm_record, setup_required
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
@console_ns.route("/features")
@ -28,7 +28,32 @@ class FeatureApi(Resource):
"""Get feature configuration for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_features(current_tenant_id).model_dump()
payload = FeatureService.get_features(
current_tenant_id,
exclude_vector_space=True,
).model_dump()
payload.pop("vector_space", None)
return payload
@console_ns.route("/features/vector-space")
class FeatureVectorSpaceApi(Resource):
@console_ns.doc("get_tenant_feature_vector_space")
@console_ns.doc(description="Get vector-space usage and limit for current tenant")
@console_ns.response(
200,
"Success",
console_ns.models[LimitationModel.__name__],
)
@setup_required
@login_required
@account_initialization_required
@cloud_utm_record
def get(self):
"""Get vector-space usage and limit for current tenant"""
_, current_tenant_id = current_account_with_tenant()
return FeatureService.get_vector_space(current_tenant_id).model_dump()
@console_ns.route("/system-features")

View File

@ -1,6 +1,5 @@
import urllib.parse
import httpx
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
@ -34,7 +33,7 @@ class GetRemoteFileInfo(Resource):
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
@login_required
def get(self, url: str):
decoded_url = urllib.parse.unquote(url)
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)

View File

@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")

View File

@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
credential_id=args.credential_id,
)
return {"result": "success"}, 204
return "", 204
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
from typing import Literal
from typing import Literal, override
from dateutil.parser import isoparse
from flask import request
@ -76,11 +76,13 @@ def _enum_value(value):
class WorkflowRunStatusField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
return _enum_value(obj.status)
class WorkflowRunOutputsField(fields.Raw):
@override
def output(self, key, obj: WorkflowRun, **kwargs):
status = _enum_value(obj.status)
if status == WorkflowExecutionStatus.PAUSED.value:

View File

@ -1,13 +1,17 @@
from typing import Any, Literal, cast
from typing import Any, Literal
from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
from controllers.common.schema import (
query_params_from_model,
register_enum_models,
register_response_schema_models,
register_schema_models,
)
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
@ -17,9 +21,10 @@ from controllers.service_api.wraps import (
)
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from fields.base import ResponseModel
from fields.dataset_fields import DatasetDetailResponse
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@ -119,6 +124,21 @@ class TagUnbindingPayload(BaseModel):
return self
class KnowledgeTagResponse(ResponseModel):
model_config = ConfigDict(coerce_numbers_to_str=True)
id: str
name: str
type: str
# TODO: The public Service API docs expose binding_count as string|null.
# Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
binding_count: str | None = None
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
pass
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
@ -127,6 +147,29 @@ class DatasetListQuery(BaseModel):
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
partial_member_list: list[str] | None = None
# todo: duplicate code, but the partial_member_list has different nullability
class DatasetListResponse(ResponseModel):
data: list[DatasetDetailResponse]
has_more: bool
limit: int
total: int
page: int
class DatasetBoundTagResponse(ResponseModel):
id: str
name: str
class DatasetBoundTagListResponse(ResponseModel):
data: list[DatasetBoundTagResponse]
total: int
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@ -137,9 +180,17 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
register_response_schema_models(
service_api_ns,
SimpleResultResponse,
KnowledgeTagResponse,
KnowledgeTagListResponse,
DatasetDetailResponse,
DatasetDetailWithPartialMembersResponse,
DatasetListResponse,
DatasetBoundTagListResponse,
)
@service_api_ns.route("/datasets")
@ -154,9 +205,18 @@ class DatasetListApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
@service_api_ns.response(
200,
"Datasets retrieved successfully",
service_api_ns.models[DatasetListResponse.__name__],
)
def get(self, tenant_id):
"""Resource for getting datasets."""
query = DatasetListQuery.model_validate(request.args.to_dict())
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = DatasetListQuery.model_validate(query_params)
# provider = request.args.get("provider", default="vendor")
datasets, total = DatasetService.get_datasets(
@ -175,22 +235,17 @@ class DatasetListApi(DatasetApiResource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
for item in data:
if (
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
):
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
)
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
if item_model in model_names:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
else:
item["embedding_available"] = False # type: ignore
item["embedding_available"] = False
else:
item["embedding_available"] = True # type: ignore
item["embedding_available"] = True
response = {
"data": data,
"has_more": len(datasets) == query.limit,
@ -198,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
"total": total,
"page": query.page,
}
return response, 200
return dump_response(DatasetListResponse, response), 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
@service_api_ns.doc("create_dataset")
@ -210,6 +265,11 @@ class DatasetListApi(DatasetApiResource):
400: "Bad request - invalid parameters",
}
)
@service_api_ns.response(
200,
"Dataset created successfully",
service_api_ns.models[DatasetDetailResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
@ -253,7 +313,7 @@ class DatasetListApi(DatasetApiResource):
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return marshal(dataset, dataset_detail_fields), 200
return dump_response(DatasetDetailResponse, dataset), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>")
@ -271,6 +331,11 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset retrieved successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -280,7 +345,7 @@ class DatasetApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
data = dump_response(DatasetDetailResponse, dataset)
# check embedding setting
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
@ -312,7 +377,13 @@ class DatasetApi(DatasetApiResource):
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
return (
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
mode="json",
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
),
200,
)
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
@service_api_ns.doc("update_dataset")
@ -326,6 +397,11 @@ class DatasetApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200,
"Dataset updated successfully",
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
@ -376,7 +452,7 @@ class DatasetApi(DatasetApiResource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
result_data = dump_response(DatasetDetailResponse, dataset)
assert isinstance(current_user, Account)
tenant_id = current_user.current_tenant_id
@ -389,7 +465,7 @@ class DatasetApi(DatasetApiResource):
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
@service_api_ns.doc("delete_dataset")
@service_api_ns.doc(description="Delete a dataset")
@ -502,7 +578,7 @@ class DocumentStatusApi(DatasetApiResource):
except ValueError as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
return dump_response(SimpleResultResponse, {"result": "success"}), 200
@service_api_ns.route("/datasets/tags")
@ -515,14 +591,18 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[KnowledgeTagListResponse.__name__],
)
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
return dump_response(KnowledgeTagListResponse, tags), 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@ -534,6 +614,11 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag created successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@ -543,9 +628,10 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
)
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@ -558,6 +644,11 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.response(
200,
"Tag updated successfully",
service_api_ns.models[KnowledgeTagResponse.__name__],
)
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -569,9 +660,10 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
response = dump_response(
KnowledgeTagResponse,
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
)
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
@ -656,6 +748,11 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Tags retrieved successfully",
service_api_ns.models[DatasetBoundTagListResponse.__name__],
)
def get(self, _, *args, **kwargs):
"""Get all knowledge type tags."""
dataset_id = kwargs.get("dataset_id")
@ -663,5 +760,4 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
assert current_user.current_tenant_id is not None
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
response = {"data": tags_list, "total": len(tags)}
return response, 200
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200

View File

@ -1,15 +1,19 @@
from typing import Literal
from flask_login import current_user
from flask_restx import marshal
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
from fields.dataset_fields import (
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
from libs.helper import dump_response
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
@ -27,7 +31,13 @@ register_schema_models(
DocumentMetadataOperation,
MetadataOperationData,
)
register_response_schema_models(service_api_ns, SimpleResultResponse)
register_response_schema_models(
service_api_ns,
DatasetMetadataActionResponse,
DatasetMetadataBuiltInFieldsResponse,
DatasetMetadataListResponse,
DatasetMetadataResponse,
)
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
@ -43,6 +53,9 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
@ -55,7 +68,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
return marshal(metadata, dataset_metadata_fields), 201
return dump_response(DatasetMetadataResponse, metadata), 201
@service_api_ns.doc("get_dataset_metadata")
@service_api_ns.doc(description="Get all metadata for a dataset")
@ -67,13 +80,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
404: "Dataset not found",
}
)
@service_api_ns.response(
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
)
def get(self, tenant_id, dataset_id):
"""Get all metadata for a dataset."""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return MetadataService.get_dataset_metadatas(dataset), 200
metadata = MetadataService.get_dataset_metadatas(dataset)
return dump_response(DatasetMetadataListResponse, metadata), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
@ -89,6 +106,9 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
"""Update metadata name."""
@ -102,7 +122,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
return marshal(metadata, dataset_metadata_fields), 200
return dump_response(DatasetMetadataResponse, metadata), 200
@service_api_ns.doc("delete_dataset_metadata")
@service_api_ns.doc(description="Delete metadata")
@ -114,6 +134,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
404: "Dataset or metadata not found",
}
)
@service_api_ns.response(204, "Metadata deleted successfully")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id):
"""Delete metadata."""
@ -138,10 +159,15 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Built-in fields retrieved successfully",
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
)
def get(self, tenant_id, dataset_id):
"""Get all built-in metadata fields."""
built_in_fields = MetadataService.get_built_in_fields()
return {"fields": built_in_fields}, 200
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
@ -157,9 +183,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
}
)
@service_api_ns.response(
200,
"Action completed successfully",
service_api_ns.models[SimpleResultResponse.__name__],
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
@ -175,7 +199,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
@ -194,7 +218,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
@service_api_ns.response(
200,
"Documents metadata updated successfully",
service_api_ns.models[SimpleResultResponse.__name__],
service_api_ns.models[DatasetMetadataActionResponse.__name__],
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
@ -209,4 +233,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
MetadataService.update_documents_metadata(dataset, metadata_args)
return {"result": "success"}, 200
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200

View File

@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
@service_api_ns.route("/")
class IndexApi(Resource):
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
def get(self):
def get(self) -> dict[str, str]:
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",

View File

@ -136,7 +136,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204
@web_ns.route("/conversations/<uuid:c_id>/name")

View File

@ -1,6 +1,5 @@
import urllib.parse
import httpx
from flask import request
from pydantic import BaseModel, Field, HttpUrl
import services
@ -59,7 +58,7 @@ class RemoteFileInfoApi(WebApiResource):
Raises:
HTTPException: If the remote file cannot be accessed
"""
decoded_url = urllib.parse.unquote(url)
decoded_url = helpers.decode_remote_url(url, request.query_string)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method

View File

@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return ResultResponse(result="success").model_dump(mode="json"), 204
return "", 204

View File

@ -1,4 +1,5 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.file import file_manager
@ -66,6 +67,7 @@ class CotChatAgentRunner(CotAgentRunner):
return prompt_messages
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize

View File

@ -1,4 +1,5 @@
import json
from typing import override
from core.agent.cot_agent_runner import CotAgentRunner
from graphon.model_runtime.entities.message_entities import (
@ -51,6 +52,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
return historic_prompt
@override
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Sequence
from typing import Any
from typing import Any, override
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
self.declaration = declaration
self.meta_version = meta_version
@override
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
return params
@override
def _invoke(
self,
params: dict[str, Any],

View File

@ -55,6 +55,7 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
@ -145,9 +146,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation = None
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
try:
conversation = ConversationService.get_conversation(
app_model=app_model, conversation_id=conversation_id, user=user
)
except ConversationNotExistsError:
if invoke_from == InvokeFrom.SERVICE_API:
conversation = None
else:
raise
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
) -> dict[str, Any]:
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, Any, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking full response.
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
"""
Convert blocking simple response.
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from pydantic import JsonValue
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
@classmethod
@override
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking full response.
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
"""
Convert blocking simple response.
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
return response
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import abc
from collections.abc import Mapping
from typing import Any, Protocol
from typing import Any, Protocol, override
from graphon.enums import NodeType
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
class NoopDraftVariableSaver(DraftVariableSaver):
@override
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
return None

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
self._app_mode = app_mode
self._message_id = str(message_id)
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue

View File

@ -1,3 +1,5 @@
from typing import override
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
self._app_mode = app_mode
@override
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import Any, cast, override
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
):
@classmethod
@override
def convert_blocking_full_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
return dict(blocking_response.model_dump())
@classmethod
@override
def convert_blocking_simple_response(
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
) -> dict[str, Any]:
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
return cls.convert_blocking_full_response(blocking_response)
@classmethod
@override
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:
@ -73,6 +76,7 @@ class WorkflowAppGenerateResponseConverter(
yield response_chunk
@classmethod
@override
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict[str, Any] | str, None, None]:

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Callable
from typing import override
from sqlalchemy import and_, or_, select
from sqlalchemy.orm import Session
@ -31,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
) -> None:
self._scope_getter = scope_getter
@override
def current_scope(self) -> FileAccessScope | None:
return self._scope_getter()
@override
def apply_upload_file_filters(
self,
stmt: Select[tuple[UploadFile]],
@ -62,6 +65,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
)
)
@override
def apply_tool_file_filters(
self,
stmt: Select[tuple[ToolFile]],
@ -78,6 +82,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
@override
def get_upload_file(
self,
*,
@ -95,6 +100,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
)
return session.scalar(stmt)
@override
def get_tool_file(
self,
*,

View File

@ -8,6 +8,7 @@ scope updates that matter to chat applications.
"""
import logging
from typing import override
from core.workflow.system_variables import SystemVariableKey, get_system_text
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
@ -23,9 +24,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
super().__init__()
self._conversation_variable_updater = conversation_variable_updater
@override
def on_graph_start(self) -> None:
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunVariableUpdatedEvent):
return
@ -44,5 +47,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self
from typing import Annotated, Literal, Self, override
from pydantic import BaseModel, Field
from sqlalchemy import Engine
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
def _get_repo(self) -> APIWorkflowRunRepository:
return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker)
@override
def on_graph_start(self) -> None:
"""
Called when graph execution starts.
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
pass
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""
Called for every event emitted by the engine.
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
pause_reasons=event.reasons,
)
@override
def on_graph_end(self, error: Exception | None) -> None:
"""
Called when graph execution ends.

View File

@ -1,3 +1,5 @@
from typing import override
from graphon.graph_engine.layers import GraphEngineLayer
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
super().__init__()
self._paused = False
@override
def on_graph_start(self):
self._paused = False
@override
def on_event(self, event: GraphEngineEvent):
"""
Handle the paused event, stash runtime state into storage and wait for resume.
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
if isinstance(event, GraphRunPausedEvent):
self._paused = True
@override
def on_graph_end(self, error: Exception | None):
""" """
self._paused = False

View File

@ -1,6 +1,6 @@
import logging
import uuid
from typing import ClassVar
from typing import ClassVar, override
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
except Exception:
logger.exception("scheduler error during check if the workflow need to be suspended")
@override
def on_graph_start(self):
"""
Start timer to check if the workflow need to be suspended.
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
id=self.schedule_id,
)
@override
def on_event(self, event: GraphEngineEvent):
pass
@override
def on_graph_end(self, error: Exception | None) -> None:
self.stopped = True
# remove the scheduler

View File

@ -1,6 +1,6 @@
import logging
from datetime import UTC, datetime
from typing import Any, ClassVar
from typing import Any, ClassVar, override
from pydantic import TypeAdapter
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@override
def on_graph_start(self):
pass
@override
def on_event(self, event: GraphEngineEvent):
"""
Update trigger log with success or failure.
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
repo.update(trigger_log)
session.commit()
@override
def on_graph_end(self, error: Exception | None) -> None:
pass

View File

@ -7,7 +7,7 @@ import os
import time
import urllib.parse
from collections.abc import Generator
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, override
from configs import dify_config
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
self._file_access_controller = file_access_controller
@property
@override
def multimodal_send_format(self) -> str:
return dify_config.MULTIMODAL_SEND_FORMAT
@override
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
@override
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
return storage.load(path, stream=stream)
@override
def load_file_bytes(self, *, file: File) -> bytes:
storage_key = self._resolve_storage_key(file=file)
data = storage.load(storage_key, stream=False)
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
raise ValueError(f"file {storage_key} is not a bytes object")
return data
@override
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
if file.transfer_method == FileTransferMethod.REMOTE_URL:
return file.remote_url
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
)
return None
@override
def resolve_upload_file_url(
self,
*,
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
query["as_attachment"] = "true"
return f"{url}?{urllib.parse.urlencode(query)}"
@override
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
self._assert_tool_file_access(tool_file_id=tool_file_id)
return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
@override
def verify_preview_signature(
self,
*,

View File

@ -12,7 +12,7 @@ state.
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from typing import Any, Union, override
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.helper.trace_id_helper import ParentTraceContext
@ -98,12 +98,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
# ------------------------------------------------------------------
# GraphEngineLayer lifecycle
# ------------------------------------------------------------------
@override
def on_graph_start(self) -> None:
self._workflow_execution = None
self._node_execution_cache.clear()
self._node_snapshots.clear()
self._node_sequence = 0
@override
def on_event(self, event: GraphEngineEvent) -> None:
match event:
case GraphRunStartedEvent():
@ -131,6 +133,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
case NodeRunPauseRequestedEvent():
self._handle_node_pause_requested(event)
@override
def on_graph_end(self, error: Exception | None) -> None:
return

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
self.tenant_id = tenant_id
self.plugin_unique_identifier = plugin_unique_identifier
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE
@override
def get_icon_url(self, tenant_id: str) -> str:
return self.icon

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@override
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,4 +1,5 @@
from collections.abc import Generator
from typing import override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -67,5 +68,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping
from typing import Any
from typing import Any, override
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type,
)
@override
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@ -1,3 +1,5 @@
from typing import override
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
self.plugin_unique_identifier = plugin_unique_identifier
@property
@override
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider

View File

@ -6,7 +6,7 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from sqlalchemy import func, select
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
key = str(ModelProviderID(key))
return key in self.configurations
@override
def __iter__(self):
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any, TypedDict, override
from sqlalchemy import select
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@override
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,4 +1,5 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class JavascriptCodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.JAVASCRIPT
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,10 +1,12 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.template_transformer import TemplateTransformer
class NodeJsTemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from textwrap import dedent
from typing import Any
from typing import Any, override
from core.helper.code_executor.template_transformer import TemplateTransformer
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
_template_b64_placeholder: str = "{{template_b64}}"
@classmethod
@override
def transform_response(cls, response: str):
"""
Transform response to dict
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return {"result": cls.extract_result_str_from_response(response)}
@classmethod
@override
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
"""
Override base class to use base64 encoding for template code.
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return script
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f"""
import jinja2
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
return runner_script
@classmethod
@override
def get_preload_script(cls) -> str:
preload_script = dedent("""
import jinja2

View File

@ -1,4 +1,5 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.code_executor import CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
class Python3CodeProvider(CodeNodeProvider):
@staticmethod
@override
def get_language() -> str:
return CodeLanguage.PYTHON3
@classmethod
@override
def get_default_code(cls) -> str:
return dedent(
"""

View File

@ -1,10 +1,12 @@
from textwrap import dedent
from typing import override
from core.helper.code_executor.template_transformer import TemplateTransformer
class Python3TemplateTransformer(TemplateTransformer):
@classmethod
@override
def get_runner_script(cls) -> str:
runner_script = dedent(f""" {cls._code_placeholder}

View File

@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any
from typing import Any, override
from extensions.ext_redis import redis_client
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
provider_identity=provider_identity,
)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
@override
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]

View File

@ -43,13 +43,16 @@ request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]:
"""Build per-scheme proxy transports with the same TLS policy as the SSRF client."""
return {
"http://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTP_URL,
verify=verify,
),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL,
verify=verify,
),
}
@ -64,7 +67,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client:
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
return httpx.Client(
mounts=_create_proxy_mounts(),
mounts=_create_proxy_mounts(verify=verify),
verify=verify,
limits=_SSRF_CLIENT_LIMITS,
)

View File

@ -2,6 +2,7 @@
import contextlib
import logging
from typing import override
import flask
@ -15,6 +16,7 @@ class TraceContextFilter(logging.Filter):
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
@ -54,6 +56,7 @@ class IdentityContextFilter(logging.Filter):
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
@override
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, NotRequired, TypedDict
from typing import Any, NotRequired, TypedDict, override
import orjson
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
@override
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:

View File

@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
import logging
from collections.abc import Callable
from typing import Any
from typing import Any, override
from sqlalchemy.orm import Session
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
# Reset retry flag after operation completes
self._has_retried = False
@override
def __enter__(self):
"""Enter the context manager with retry support."""
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
return self._execute_with_retry(initialize_with_retry)
@override
def list_tools(self) -> list[Tool]:
"""
List available tools from the MCP server with auth retry.
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
"""
return self._execute_with_retry(super().list_tools)
@override
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
"""
Invoke a tool on the MCP server with auth retry.

View File

@ -1,6 +1,6 @@
import queue
from datetime import timedelta
from typing import Any, Protocol
from typing import Any, Protocol, override
from pydantic import AnyUrl, TypeAdapter
@ -159,6 +159,7 @@ class ClientSession(
types.EmptyResult,
)
@override
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
"""Send a progress notification."""
self.send_notification(
@ -326,6 +327,7 @@ class ClientSession(
)
)
@override
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
@ -351,6 +353,7 @@ class ClientSession(
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
@override
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
@ -358,6 +361,7 @@ class ClientSession(
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
@override
def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server."""
# Process specific notification types

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from pydantic import BaseModel, Field
from sqlalchemy import select
@ -25,6 +25,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -43,6 +44,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -59,6 +61,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any
from typing import Any, override
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -8,6 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -28,6 +29,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -49,6 +51,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, override
from core.model_manager import ModelManager
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
@ -9,6 +9,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
@override
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -19,6 +20,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
@override
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
@ -36,6 +38,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
@override
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""

View File

@ -1,4 +1,5 @@
from collections.abc import Mapping
from typing import override
from pydantic import TypeAdapter
@ -11,6 +12,7 @@ class PluginDaemonError(Exception):
def __init__(self, description: str):
self.description = description
@override
def __str__(self) -> str:
# returns the class name and description
return f"req_id: {get_request_id()} {self.__class__.__name__}: {self.description}"

View File

@ -4,7 +4,7 @@ import hashlib
import logging
from collections.abc import Generator, Iterable, Sequence
from threading import Lock
from typing import IO, Any, Literal, cast, overload
from typing import IO, Any, Literal, cast, overload, override
from pydantic import ValidationError
from redis import RedisError
@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime):
self._provider_entities = None
self._provider_entities_lock = Lock()
@override
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
if self._provider_entities is not None:
return self._provider_entities
@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime):
return self._provider_entities
@override
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
provider_schema = self._get_provider_schema(provider)
@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime):
mime_type = image_mime_types.get(extension, "image/png")
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
@override
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
plugin_id, provider_name = self._split_provider(provider)
self.client.validate_provider_credentials(
@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def validate_model_credentials(
self,
*,
@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime):
credentials=credentials,
)
@override
def get_model_schema(
self,
*,
@ -294,6 +299,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunk, None, None]: ...
@override
def invoke_llm(
self,
*,
@ -357,6 +363,7 @@ class PluginModelRuntime(ModelRuntime):
stream: Literal[True],
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@override
def invoke_llm_with_structured_output(
self,
*,
@ -396,6 +403,7 @@ class PluginModelRuntime(ModelRuntime):
stream=stream,
)
@override
def get_llm_num_tokens(
self,
*,
@ -422,6 +430,7 @@ class PluginModelRuntime(ModelRuntime):
tools=list(tools) if tools else None,
)
@override
def invoke_text_embedding(
self,
*,
@ -443,6 +452,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def invoke_multimodal_embedding(
self,
*,
@ -464,6 +474,7 @@ class PluginModelRuntime(ModelRuntime):
input_type=input_type,
)
@override
def get_text_embedding_num_tokens(
self,
*,
@ -483,6 +494,7 @@ class PluginModelRuntime(ModelRuntime):
texts=texts,
)
@override
def invoke_rerank(
self,
*,
@ -508,6 +520,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_multimodal_rerank(
self,
*,
@ -533,6 +546,7 @@ class PluginModelRuntime(ModelRuntime):
top_n=top_n,
)
@override
def invoke_tts(
self,
*,
@ -554,6 +568,7 @@ class PluginModelRuntime(ModelRuntime):
voice=voice,
)
@override
def get_tts_model_voices(
self,
*,
@ -573,6 +588,7 @@ class PluginModelRuntime(ModelRuntime):
language=language,
)
@override
def invoke_speech_to_text(
self,
*,
@ -592,6 +608,7 @@ class PluginModelRuntime(ModelRuntime):
file=file,
)
@override
def invoke_moderation(
self,
*,

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, TypedDict
from typing import Any, TypedDict, override
import orjson
from pydantic import BaseModel
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
super().__init__(dataset)
self._config = KeywordTableConfig()
@override
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
return self
@override
def add_texts(self, texts: list[Document], **kwargs):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def text_exists(self, id: str) -> bool:
keyword_table = self._get_dataset_keyword_table()
if keyword_table is None:
return False
return id in set.union(*keyword_table.values())
@override
def delete_by_ids(self, ids: list[str]):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
self._save_dataset_keyword_table(keyword_table)
@override
def search(self, query: str, **kwargs: Any) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
return documents
@override
def delete(self):
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
with redis_client.lock(lock_name, timeout=600):

View File

@ -2,7 +2,7 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, override
from sqlalchemy import select
@ -72,21 +72,27 @@ class _LazyEmbeddings(Embeddings):
self._real = CacheEmbedding(embedding_model)
return self._real
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._ensure().embed_documents(texts)
@override
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
return self._ensure().embed_multimodal_documents(multimodel_documents)
@override
def embed_query(self, text: str) -> list[float]:
return self._ensure().embed_query(text)
@override
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
return self._ensure().embed_multimodal_query(multimodel_document)
@override
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return await self._ensure().aembed_documents(texts)
@override
async def aembed_query(self, text: str) -> list[float]:
return await self._ensure().aembed_query(text)

View File

@ -37,6 +37,10 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from core.workflow.nodes.agent_v2 import DifyAgentNode
from core.workflow.nodes.agent_v2.binding_resolver import WorkflowAgentBindingResolver
from core.workflow.nodes.agent_v2.output_adapter import WorkflowAgentOutputAdapter
from core.workflow.nodes.agent_v2.runtime_request_builder import WorkflowAgentRuntimeRequestBuilder
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
from graphon.entities.base_node_data import BaseNodeData
@ -438,12 +442,7 @@ class DifyNodeFactory(NodeFactory):
"tool_file_manager": self._bound_tool_file_manager_factory(),
"runtime": self._tool_runtime,
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
BuiltinNodeTypes.AGENT: lambda: self._build_agent_node_init_kwargs(node_class=node_class),
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
@ -469,6 +468,32 @@ class DifyNodeFactory(NodeFactory):
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
def _build_agent_node_init_kwargs(self, *, node_class: type[Node]) -> dict[str, object]:
if issubclass(node_class, DifyAgentNode):
from clients.agent_backend import AgentBackendRunEventAdapter, AgentBackendRunRequestBuilder
from clients.agent_backend.factory import create_agent_backend_run_client
return {
"binding_resolver": WorkflowAgentBindingResolver(),
"runtime_request_builder": WorkflowAgentRuntimeRequestBuilder(
credentials_provider=self._llm_credentials_provider,
request_builder=AgentBackendRunRequestBuilder(),
),
"agent_backend_client": create_agent_backend_run_client(
base_url=dify_config.AGENT_BACKEND_BASE_URL,
use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
),
"event_adapter": AgentBackendRunEventAdapter(),
"output_adapter": WorkflowAgentOutputAdapter(),
}
return {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
}
def _build_llm_compatible_node_init_kwargs(
self,
*,

View File

@ -0,0 +1,4 @@
from .agent_node import DifyAgentNode
from .entities import DifyAgentNodeData
__all__ = ["DifyAgentNode", "DifyAgentNodeData"]

View File

@ -0,0 +1,281 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from clients.agent_backend import (
AgentBackendError,
AgentBackendHTTPError,
AgentBackendInternalEventType,
AgentBackendRunClient,
AgentBackendRunEventAdapter,
AgentBackendRunFailedInternalEvent,
AgentBackendRunSucceededInternalEvent,
AgentBackendStreamError,
AgentBackendStreamInternalEvent,
AgentBackendTransportError,
AgentBackendValidationError,
)
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from graphon.nodes.base.node import Node
from .binding_resolver import WorkflowAgentBindingError, WorkflowAgentBindingResolver
from .entities import DifyAgentNodeData
from .output_adapter import WorkflowAgentOutputAdapter
from .runtime_request_builder import (
WorkflowAgentRuntimeBuildContext,
WorkflowAgentRuntimeRequestBuilder,
WorkflowAgentRuntimeRequestBuildError,
)
if TYPE_CHECKING:
from graphon.entities import GraphInitParams
from graphon.runtime import GraphRuntimeState
class DifyAgentNode(Node[DifyAgentNodeData]):
node_type = BuiltinNodeTypes.AGENT
def __init__(
self,
node_id: str,
data: DifyAgentNodeData,
*,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
binding_resolver: WorkflowAgentBindingResolver,
runtime_request_builder: WorkflowAgentRuntimeRequestBuilder,
agent_backend_client: AgentBackendRunClient,
event_adapter: AgentBackendRunEventAdapter,
output_adapter: WorkflowAgentOutputAdapter,
) -> None:
super().__init__(
node_id=node_id,
data=data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._binding_resolver = binding_resolver
self._runtime_request_builder = runtime_request_builder
self._agent_backend_client = agent_backend_client
self._event_adapter = event_adapter
self._output_adapter = output_adapter
@classmethod
def version(cls) -> str:
return "2"
def populate_start_event(self, event) -> None:
event.extras["agent_node"] = {"version": "2", "agent_node_kind": self.node_data.agent_node_kind}
def _run(self) -> Generator[NodeEventBase, None, None]:
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
workflow_id = self.graph_init_params.workflow_id
workflow_run_id = get_system_text(
self.graph_runtime_state.variable_pool,
SystemVariableKey.WORKFLOW_EXECUTION_ID,
)
inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
metadata: dict[str, Any] = {
"agent_backend": {
"status": "not_started",
}
}
try:
bundle = self._binding_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
workflow_id=workflow_id,
node_id=self._node_id,
)
runtime_request = self._runtime_request_builder.build(
WorkflowAgentRuntimeBuildContext(
dify_context=dify_ctx,
workflow_id=workflow_id,
workflow_run_id=workflow_run_id,
node_id=self._node_id,
node_execution_id=self.id,
variable_pool=self.graph_runtime_state.variable_pool,
binding=bundle.binding,
agent=bundle.agent,
snapshot=bundle.snapshot,
)
)
inputs = {"agent_backend_request": runtime_request.redacted_request}
metadata = dict(runtime_request.metadata)
process_data = {
"agent_id": bundle.agent.id,
"agent_config_snapshot_id": bundle.snapshot.id,
"binding_id": bundle.binding.id,
}
create_response = self._agent_backend_client.create_run(runtime_request.request)
metadata["agent_backend"] = {
**dict(metadata.get("agent_backend") or {}),
"run_id": create_response.run_id,
"status": create_response.status,
}
except WorkflowAgentBindingError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=error.error_code,
)
return
except WorkflowAgentRuntimeRequestBuildError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=error.error_code,
)
return
except AgentBackendError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=self._agent_backend_error_type(error),
)
return
except Exception as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type="agent_workflow_node_runtime_error",
)
return
stream_event_count = 0
try:
for public_event in self._agent_backend_client.stream_events(create_response.run_id):
stream_event_count += 1
for internal_event in self._event_adapter.adapt(public_event):
if internal_event.type == AgentBackendInternalEventType.RUN_STARTED:
continue
if internal_event.type == AgentBackendInternalEventType.STREAM_EVENT:
if isinstance(internal_event, AgentBackendStreamInternalEvent):
self._record_stream_metadata(metadata, internal_event)
continue
metadata["agent_backend"] = {
**dict(metadata.get("agent_backend") or {}),
"stream_event_count": stream_event_count,
}
if isinstance(internal_event, AgentBackendRunSucceededInternalEvent):
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_success_result(
event=internal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
if isinstance(
internal_event,
AgentBackendRunFailedInternalEvent,
) or internal_event.type in {
AgentBackendInternalEventType.RUN_CANCELLED,
AgentBackendInternalEventType.RUN_PAUSED,
}:
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_failure_result(
event=internal_event,
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
return
except AgentBackendError as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type=self._agent_backend_error_type(error),
)
return
except Exception as error:
yield self._failure_event(
inputs=inputs,
process_data=process_data,
metadata=metadata,
error=str(error),
error_type="agent_backend_stream_error",
)
return
yield StreamCompletedEvent(
node_run_result=self._output_adapter.build_stream_exhausted_result(
inputs=inputs,
process_data=process_data,
metadata=metadata,
)
)
@staticmethod
def _failure_event(
*,
inputs: dict[str, Any],
process_data: dict[str, Any],
metadata: dict[str, Any],
error: str,
error_type: str,
) -> StreamCompletedEvent:
return StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
metadata={WorkflowNodeExecutionMetadataKey.AGENT_LOG: metadata},
outputs={},
error=error,
error_type=error_type,
)
)
@staticmethod
def _agent_backend_error_type(error: AgentBackendError) -> str:
if isinstance(error, AgentBackendValidationError):
return "agent_backend_validation_error"
if isinstance(error, AgentBackendHTTPError):
return "agent_backend_http_error"
if isinstance(error, AgentBackendStreamError):
return "agent_backend_stream_error"
if isinstance(error, AgentBackendTransportError):
return "agent_backend_transport_error"
return "agent_backend_error"
@staticmethod
def _record_stream_metadata(metadata: dict[str, Any], event: AgentBackendStreamInternalEvent) -> None:
agent_backend = dict(metadata.get("agent_backend") or {})
agent_backend["last_stream_event_id"] = event.source_event_id
if event.event_kind:
agent_backend["last_stream_event_kind"] = event.event_kind
if isinstance(event.data, Mapping):
usage = event.data.get("usage") or event.data.get("model_usage")
if isinstance(usage, Mapping):
agent_backend["usage"] = dict(usage)
metadata["agent_backend"] = agent_backend
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: DifyAgentNodeData,
) -> Mapping[str, Sequence[str]]:
del graph_config, node_id, node_data
return {}

View File

@ -0,0 +1,93 @@
from __future__ import annotations
from dataclasses import dataclass
from sqlalchemy import select
from core.db.session_factory import session_factory
from models.agent import Agent, AgentConfigSnapshot, AgentStatus, WorkflowAgentNodeBinding
class WorkflowAgentBindingError(Exception):
error_code: str
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
@dataclass(frozen=True, slots=True)
class WorkflowAgentBindingBundle:
binding: WorkflowAgentNodeBinding
agent: Agent
snapshot: AgentConfigSnapshot
class WorkflowAgentBindingResolver:
"""Resolve the Agent binding owned by the current workflow id and node id."""
def resolve(
self,
*,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> WorkflowAgentBindingBundle:
with session_factory.create_session() as session:
binding = session.scalar(
select(WorkflowAgentNodeBinding)
.where(
WorkflowAgentNodeBinding.tenant_id == tenant_id,
WorkflowAgentNodeBinding.app_id == app_id,
WorkflowAgentNodeBinding.workflow_id == workflow_id,
WorkflowAgentNodeBinding.node_id == node_id,
)
.limit(1)
)
if binding is None:
raise WorkflowAgentBindingError(
"agent_binding_not_found",
f"Workflow Agent binding not found for node {node_id}.",
)
if binding.agent_id is None:
raise WorkflowAgentBindingError("agent_not_available", "Workflow Agent binding has no agent.")
if binding.current_snapshot_id is None:
raise WorkflowAgentBindingError(
"agent_config_snapshot_not_found",
"Workflow Agent binding has no current config snapshot.",
)
agent = session.scalar(
select(Agent)
.where(
Agent.tenant_id == tenant_id,
Agent.id == binding.agent_id,
)
.limit(1)
)
if agent is None or agent.status == AgentStatus.ARCHIVED:
raise WorkflowAgentBindingError(
"agent_not_available",
f"Agent {binding.agent_id} is not available.",
)
snapshot = session.scalar(
select(AgentConfigSnapshot)
.where(
AgentConfigSnapshot.tenant_id == tenant_id,
AgentConfigSnapshot.agent_id == agent.id,
AgentConfigSnapshot.id == binding.current_snapshot_id,
)
.limit(1)
)
if snapshot is None:
raise WorkflowAgentBindingError(
"agent_config_snapshot_not_found",
f"Agent config snapshot {binding.current_snapshot_id} not found.",
)
session.expunge(binding)
session.expunge(agent)
session.expunge(snapshot)
return WorkflowAgentBindingBundle(binding=binding, agent=agent, snapshot=snapshot)

View File

@ -0,0 +1,17 @@
from typing import Literal
from pydantic import model_validator
from graphon.entities.base_node_data import BaseNodeData
from graphon.enums import BuiltinNodeTypes, NodeType
class DifyAgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
agent_node_kind: Literal["dify_agent"] = "dify_agent"
@model_validator(mode="after")
def validate_version(self) -> "DifyAgentNodeData":
if self.version != "2":
raise ValueError("Dify Agent Node v2 requires version='2'")
return self

View File

@ -0,0 +1,255 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from clients.agent_backend import (
AgentBackendInternalEvent,
AgentBackendInternalEventType,
AgentBackendRunCancelledInternalEvent,
AgentBackendRunFailedInternalEvent,
AgentBackendRunPausedInternalEvent,
AgentBackendRunSucceededInternalEvent,
)
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from graphon.file import File, FileTransferMethod, FileType
from graphon.model_runtime.entities.llm_entities import LLMUsage
from graphon.node_events import NodeRunResult
from graphon.variables.segments import ArrayFileSegment, FileSegment
class WorkflowAgentOutputAdapter:
"""Convert terminal Agent backend events into workflow node run results."""
def build_success_result(
self,
*,
event: AgentBackendRunSucceededInternalEvent,
inputs: dict[str, Any],
process_data: dict[str, Any],
metadata: dict[str, Any],
) -> NodeRunResult:
metadata = self._with_terminal_metadata(metadata, event, "succeeded")
usage = self._usage_from_metadata(metadata)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=self._normalize_outputs(event.output),
metadata=self._build_node_metadata(metadata=metadata, usage=usage),
llm_usage=usage or LLMUsage.empty_usage(),
)
def build_failure_result(
self,
*,
event: (
AgentBackendRunFailedInternalEvent
| AgentBackendRunCancelledInternalEvent
| AgentBackendRunPausedInternalEvent
),
inputs: dict[str, Any],
process_data: dict[str, Any],
metadata: dict[str, Any],
) -> NodeRunResult:
status = WorkflowNodeExecutionStatus.FAILED
error = "Agent backend run failed."
error_type = "agent_backend_run_failed"
terminal_status = "failed"
match event:
case AgentBackendRunFailedInternalEvent():
error = event.error
error_type = event.reason or "agent_backend_run_failed"
terminal_status = "failed"
case AgentBackendRunCancelledInternalEvent():
error = event.message or "Agent backend run was cancelled."
error_type = "agent_backend_run_cancelled"
terminal_status = "cancelled"
case AgentBackendRunPausedInternalEvent():
error = event.message or "Agent backend run paused, but workflow Agent Node pause is not supported yet."
error_type = "agent_backend_paused_unsupported"
terminal_status = "paused"
metadata = self._with_terminal_metadata(metadata, event, terminal_status)
usage = self._usage_from_metadata(metadata)
return NodeRunResult(
status=status,
inputs=inputs,
process_data=process_data,
metadata=self._build_node_metadata(metadata=metadata, usage=usage),
llm_usage=usage or LLMUsage.empty_usage(),
error=error,
error_type=error_type,
)
def build_stream_exhausted_result(
self,
*,
inputs: dict[str, Any],
process_data: dict[str, Any],
metadata: dict[str, Any],
) -> NodeRunResult:
usage = self._usage_from_metadata(metadata)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
metadata=self._build_node_metadata(metadata=metadata, usage=usage),
llm_usage=usage or LLMUsage.empty_usage(),
error="Agent backend stream ended before a terminal event.",
error_type="agent_backend_stream_error",
)
@classmethod
def _normalize_outputs(cls, output: Any) -> dict[str, Any]:
if isinstance(output, dict):
if cls._is_file_payload(output):
return {"file": cls._file_segment_from_payload(output)}
return {key: cls._normalize_output_value(value) for key, value in output.items()}
if isinstance(output, str):
return {"text": output}
return {"result": output}
@classmethod
def _normalize_output_value(cls, value: Any) -> Any:
if isinstance(value, File | FileSegment | ArrayFileSegment):
return value
if isinstance(value, Mapping):
if cls._is_file_payload(value):
return cls._file_segment_from_payload(value)
return {key: cls._normalize_output_value(item) for key, item in value.items()}
if isinstance(value, list):
if value and all(isinstance(item, Mapping) and cls._is_file_payload(item) for item in value):
return ArrayFileSegment(value=[cls._file_from_payload(item) for item in value])
return [cls._normalize_output_value(item) for item in value]
return value
@staticmethod
def _is_file_payload(value: Mapping[str, Any]) -> bool:
return any(value.get(key) for key in ("file_id", "upload_file_id", "tool_file_id", "url", "remote_url")) and (
"filename" in value or "mime_type" in value or "url" in value or "remote_url" in value
)
@classmethod
def _file_segment_from_payload(cls, value: Mapping[str, Any]) -> FileSegment:
return FileSegment(value=cls._file_from_payload(value))
@classmethod
def _file_from_payload(cls, value: Mapping[str, Any]) -> File:
remote_url = cls._string_value(value.get("remote_url") or value.get("url"))
upload_file_id = cls._string_value(value.get("upload_file_id") or value.get("file_id"))
tool_file_id = cls._string_value(value.get("tool_file_id"))
filename = cls._string_value(value.get("filename") or value.get("name"))
mime_type = cls._string_value(value.get("mime_type") or value.get("mimetype"))
extension = cls._extension_from_payload(value, filename)
file_type = cls._file_type_from_payload(value, mime_type)
size = value.get("size")
if not isinstance(size, int):
size = -1
if tool_file_id:
transfer_method = FileTransferMethod.TOOL_FILE
related_id = tool_file_id
elif remote_url:
transfer_method = FileTransferMethod.REMOTE_URL
related_id = None
else:
transfer_method = FileTransferMethod.LOCAL_FILE
related_id = upload_file_id
return File(
type=file_type,
transfer_method=transfer_method,
remote_url=remote_url if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
@staticmethod
def _string_value(value: Any) -> str | None:
return value if isinstance(value, str) and value else None
@classmethod
def _extension_from_payload(cls, value: Mapping[str, Any], filename: str | None) -> str | None:
extension = cls._string_value(value.get("extension"))
if extension:
return extension if extension.startswith(".") else f".{extension}"
if filename and "." in filename:
return f".{filename.rsplit('.', 1)[1]}"
return None
@staticmethod
def _file_type_from_payload(value: Mapping[str, Any], mime_type: str | None) -> FileType:
explicit_type = value.get("type") or value.get("file_type")
if isinstance(explicit_type, str):
try:
return FileType(explicit_type)
except ValueError:
pass
if mime_type:
if mime_type.startswith("image/"):
return FileType.IMAGE
if mime_type.startswith("audio/"):
return FileType.AUDIO
if mime_type.startswith("video/"):
return FileType.VIDEO
return FileType.DOCUMENT
return FileType.CUSTOM
@staticmethod
def _usage_from_metadata(metadata: Mapping[str, Any]) -> LLMUsage | None:
agent_backend = metadata.get("agent_backend")
if not isinstance(agent_backend, Mapping):
return None
usage = agent_backend.get("usage")
if not isinstance(usage, Mapping):
return None
try:
return LLMUsage.from_metadata(usage)
except (TypeError, ValueError):
return None
@staticmethod
def _build_node_metadata(
*,
metadata: dict[str, Any],
usage: LLMUsage | None,
) -> dict[WorkflowNodeExecutionMetadataKey, Any]:
node_metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.AGENT_LOG: metadata,
}
if usage is not None:
node_metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
node_metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
node_metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
return node_metadata
@staticmethod
def _with_terminal_metadata(
metadata: dict[str, Any],
event: AgentBackendInternalEvent,
terminal_status: str,
) -> dict[str, Any]:
updated = dict(metadata)
agent_backend = dict(updated.get("agent_backend") or {})
agent_backend.update(
{
"run_id": event.run_id,
"terminal_event_id": event.source_event_id,
"status": terminal_status,
}
)
session_snapshot = None
if isinstance(event, AgentBackendRunSucceededInternalEvent | AgentBackendRunPausedInternalEvent):
session_snapshot = event.session_snapshot
if session_snapshot is not None:
agent_backend["session_snapshot"] = {
"layer_count": len(session_snapshot.layers),
}
updated["agent_backend"] = agent_backend
updated["terminal_event_type"] = AgentBackendInternalEventType(event.type).value
return updated

View File

@ -0,0 +1,55 @@
from __future__ import annotations
from typing import Any
from models.agent_config_entities import AgentSoulConfig
SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
{
"system_prompt",
"workflow_prompt",
"workflow_context",
"model",
"structured_output",
}
)
RESERVED_AGENT_BACKEND_FEATURES = frozenset(
{
"skills_files",
"tools",
"knowledge",
"human",
"env",
"sandbox",
"memory",
}
)
def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any]:
"""Describe PRD capabilities that are persisted but not executed in phase 3."""
warnings: list[dict[str, str]] = []
soul_dump = agent_soul.model_dump(mode="json")
for section in sorted(RESERVED_AGENT_BACKEND_FEATURES):
value = soul_dump.get(section)
has_value = bool(value)
if isinstance(value, dict):
has_value = any(bool(item) for item in value.values())
if has_value:
warnings.append(
{
"section": f"agent_soul.{section}",
"code": "agent_backend_layer_not_available",
"message": f"{section} is saved in Agent Soul but is not executed by Agent backend in phase 3.",
}
)
reserved_status = dict.fromkeys(sorted(RESERVED_AGENT_BACKEND_FEATURES), "reserved_not_executed")
return {
"supported": sorted(SUPPORTED_AGENT_BACKEND_FEATURES),
"reserved": sorted(RESERVED_AGENT_BACKEND_FEATURES),
"reserved_status": reserved_status,
"unsupported_runtime_warnings": warnings,
}

View File

@ -0,0 +1,288 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Literal, Protocol, cast
from dify_agent.protocol import CreateRunRequest, ExecutionContext
from clients.agent_backend import (
AgentBackendModelConfig,
AgentBackendOutputConfig,
AgentBackendRunRequestBuilder,
AgentBackendWorkflowNodeRunInput,
redact_for_agent_backend_log,
)
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom
from core.workflow.system_variables import SystemVariableKey, get_system_text
from graphon.variables.segments import Segment
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
from models.agent_config_entities import (
AgentSoulConfig,
DeclaredOutputConfig,
DeclaredOutputType,
WorkflowNodeJobConfig,
)
from .runtime_feature_manifest import build_runtime_feature_manifest
class WorkflowAgentRuntimeRequestBuildError(ValueError):
"""Raised when workflow state cannot be mapped to a valid Agent backend run request."""
def __init__(self, error_code: str, message: str) -> None:
self.error_code = error_code
super().__init__(message)
class VariablePoolReader(Protocol):
def get(self, selector: Sequence[str], /) -> Segment | None: ...
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ...
class CredentialsProvider(Protocol):
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
@dataclass(frozen=True, slots=True)
class WorkflowAgentRuntimeBuildContext:
dify_context: DifyRunContext
workflow_id: str
workflow_run_id: str | None
node_id: str
node_execution_id: str
variable_pool: VariablePoolReader
binding: WorkflowAgentNodeBinding
agent: Agent
snapshot: AgentConfigSnapshot
@dataclass(frozen=True, slots=True)
class WorkflowAgentRuntimeRequest:
request: CreateRunRequest
redacted_request: dict[str, Any]
agent_soul: AgentSoulConfig
node_job: WorkflowNodeJobConfig
metadata: dict[str, Any]
class WorkflowAgentRuntimeRequestBuilder:
"""Build public Dify Agent run requests from workflow Agent v2 runtime state."""
def __init__(
self,
*,
credentials_provider: CredentialsProvider,
request_builder: AgentBackendRunRequestBuilder | None = None,
) -> None:
self._credentials_provider = credentials_provider
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
def build(self, context: WorkflowAgentRuntimeBuildContext) -> WorkflowAgentRuntimeRequest:
agent_soul = AgentSoulConfig.model_validate(context.snapshot.config_snapshot_dict)
node_job = WorkflowNodeJobConfig.model_validate(context.binding.node_job_config_dict)
if agent_soul.model is None:
raise WorkflowAgentRuntimeRequestBuildError(
"agent_model_not_configured",
"Workflow Agent node requires Agent Soul model config.",
)
metadata = self._build_metadata(context, agent_soul, node_job)
workflow_context_prompt = self._build_workflow_context_prompt(context, node_job)
workflow_job_prompt = node_job.workflow_prompt.strip() or "Run this workflow Agent Node for the current run."
user_prompt = workflow_context_prompt.strip() or "Use the current workflow context."
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
request = self._request_builder.build_for_workflow_node(
AgentBackendWorkflowNodeRunInput(
model=AgentBackendModelConfig(
tenant_id=context.dify_context.tenant_id,
plugin_id=agent_soul.model.plugin_id,
model_provider=agent_soul.model.model_provider,
model=agent_soul.model.model,
user_id=context.dify_context.user_id,
credentials=self._normalize_credentials(credentials),
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
),
execution_context=ExecutionContext(
tenant_id=context.dify_context.tenant_id,
app_id=context.dify_context.app_id,
workflow_id=context.workflow_id,
workflow_run_id=context.workflow_run_id,
node_id=context.node_id,
node_execution_id=context.node_execution_id,
conversation_id=get_system_text(context.variable_pool, SystemVariableKey.CONVERSATION_ID),
agent_id=context.agent.id,
agent_config_version_id=context.snapshot.id,
invoke_from=self._agent_backend_invoke_from(context.dify_context.invoke_from),
),
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
workflow_node_job_prompt=workflow_job_prompt,
user_prompt=user_prompt,
output=self._build_output_config(node_job.declared_outputs),
idempotency_key=self._idempotency_key(context),
metadata=metadata,
)
)
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
return WorkflowAgentRuntimeRequest(
request=request,
redacted_request=redacted,
agent_soul=agent_soul,
node_job=node_job,
metadata=metadata,
)
@staticmethod
def _agent_backend_invoke_from(invoke_from: InvokeFrom) -> Literal["workflow_run", "single_step"]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.VALIDATION}:
return "single_step"
return "workflow_run"
@staticmethod
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
if context.workflow_run_id:
return f"{context.workflow_run_id}:{context.node_execution_id}"
return context.node_execution_id
@staticmethod
def _build_metadata(
context: WorkflowAgentRuntimeBuildContext,
agent_soul: AgentSoulConfig,
node_job: WorkflowNodeJobConfig,
) -> dict[str, Any]:
return {
"tenant_id": context.dify_context.tenant_id,
"app_id": context.dify_context.app_id,
"workflow_id": context.workflow_id,
"workflow_run_id": context.workflow_run_id,
"node_id": context.node_id,
"node_execution_id": context.node_execution_id,
"agent_id": context.agent.id,
"agent_config_snapshot_id": context.snapshot.id,
"binding_id": context.binding.id,
"workflow_node_job_mode": node_job.mode.value,
"runtime_support": build_runtime_feature_manifest(agent_soul),
}
def _build_workflow_context_prompt(
self,
context: WorkflowAgentRuntimeBuildContext,
node_job: WorkflowNodeJobConfig,
) -> str:
lines = ["Workflow context loaded for this run:"]
query = get_system_text(context.variable_pool, SystemVariableKey.QUERY)
if query:
lines.append(f"- User query: {query}")
resolved_outputs = self._resolve_previous_node_outputs(
context.variable_pool,
node_job.previous_node_output_refs,
)
if resolved_outputs:
lines.append("- Previous node outputs:")
for item in resolved_outputs:
lines.append(f" - {item['label']}: {item['value']}")
lines.append("The above workflow context is run-specific. Do not treat it as Agent Soul or persistent memory.")
return "\n".join(lines)
def _resolve_previous_node_outputs(
self,
variable_pool: VariablePoolReader,
refs: Sequence[Mapping[str, Any]],
) -> list[dict[str, Any]]:
resolved: list[dict[str, Any]] = []
for ref in refs:
selector = self._selector_from_ref(ref)
if not selector:
raise WorkflowAgentRuntimeRequestBuildError(
"invalid_previous_node_output_ref",
"Workflow Agent node has invalid previous node output ref.",
)
segment = variable_pool.get(selector)
if segment is None:
raise WorkflowAgentRuntimeRequestBuildError(
"missing_previous_node_output",
f"Workflow Agent node cannot resolve previous node output {'.'.join(selector)}.",
)
value = getattr(segment, "value", None)
resolved.append(
{
"label": ".".join(selector),
"value": self._summarize_value(value),
}
)
return resolved
@staticmethod
def _selector_from_ref(ref: Mapping[str, Any]) -> list[str] | None:
for key in ("selector", "variable_selector", "value_selector"):
value = ref.get(key)
if isinstance(value, list) and all(isinstance(item, str) for item in value):
return value
node_id = ref.get("node_id")
output_name = ref.get("output") or ref.get("name") or ref.get("variable") or ref.get("key")
if isinstance(node_id, str) and isinstance(output_name, str):
return [node_id, output_name]
return None
@staticmethod
def _summarize_value(value: Any) -> str:
text = str(value)
if len(text) > 2000:
return text[:2000] + "...[truncated]"
return text
@staticmethod
def _build_output_config(declared_outputs: Sequence[DeclaredOutputConfig]) -> AgentBackendOutputConfig | None:
if not declared_outputs:
return None
properties: dict[str, Any] = {}
required: list[str] = []
for output in declared_outputs:
properties[output.name] = WorkflowAgentRuntimeRequestBuilder._schema_for_declared_output(output)
if output.required:
required.append(output.name)
schema: dict[str, Any] = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return AgentBackendOutputConfig(json_schema=schema)
@staticmethod
def _schema_for_declared_output(output: DeclaredOutputConfig) -> dict[str, Any]:
match output.type:
case DeclaredOutputType.STRING:
schema: dict[str, Any] = {"type": "string"}
case DeclaredOutputType.NUMBER:
schema = {"type": "number"}
case DeclaredOutputType.BOOLEAN:
schema = {"type": "boolean"}
case DeclaredOutputType.OBJECT:
schema = {"type": "object"}
case DeclaredOutputType.ARRAY:
schema = {"type": "array"}
case DeclaredOutputType.FILE:
schema = {
"type": "object",
"properties": {
"file_id": {"type": "string"},
"filename": {"type": "string"},
"mime_type": {"type": "string"},
"url": {"type": "string"},
},
}
if output.description:
schema["description"] = output.description
return schema
@staticmethod
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
normalized: dict[str, str | int | float | bool | None] = {}
for key, value in credentials.items():
if isinstance(value, str | int | float | bool) or value is None:
normalized[key] = value
else:
normalized[key] = str(value)
return normalized

View File

@ -0,0 +1,388 @@
from __future__ import annotations
from collections import defaultdict, deque
from collections.abc import Iterator, Mapping, Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from graphon.enums import BuiltinNodeTypes
from models.agent import Agent, AgentConfigSnapshot, AgentStatus, WorkflowAgentNodeBinding
from models.agent_config_entities import AgentSoulConfig, WorkflowNodeJobConfig
from models.model import UploadFile
from models.workflow import Workflow
from .entities import DifyAgentNodeData
class WorkflowAgentNodeValidationError(ValueError):
"""Raised when a Workflow Agent v2 node cannot be executed or published."""
class WorkflowAgentNodeValidator:
"""Validate Agent v2 workflow nodes against graph topology and persisted bindings."""
_LOCKED_AGENT_SOUL_KEYS = frozenset(
{
"agent_soul",
"soul",
"prompt",
"system_prompt",
"skills_files",
"skills",
"files",
"tools",
"dify_tools",
"cli_tools",
"knowledge",
"env",
"environment",
"sandbox",
"sandbox_provider",
"memory",
"memory_strategy",
"model",
"app_features",
"app_variables",
"misc_legacy",
}
)
_SUPPORTED_HUMAN_CONTACT_CHANNELS = frozenset({"email", "slack", "web_app", "webapp", "chat"})
@classmethod
def validate_draft_workflow(cls, *, session: Session, workflow: Workflow) -> None:
cls._validate_workflow(session=session, workflow=workflow, require_binding=False)
@classmethod
def validate_published_workflow(cls, *, session: Session, workflow: Workflow) -> None:
cls._validate_workflow(session=session, workflow=workflow, require_binding=True)
@classmethod
def _validate_workflow(cls, *, session: Session, workflow: Workflow, require_binding: bool) -> None:
graph = workflow.graph_dict
topology = _WorkflowGraphTopology.from_graph(graph)
for node_id, node_data in cls.iter_agent_v2_nodes(graph):
cls._validate_node_schema(node_id=node_id, node_data=node_data)
binding = cls._find_binding(
session=session,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
node_id=node_id,
)
if binding is None:
if require_binding:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {node_id} requires a binding before publishing."
)
continue
cls.validate_binding(session=session, binding=binding, topology=topology)
@classmethod
def validate_binding(
cls,
*,
session: Session,
binding: WorkflowAgentNodeBinding,
topology: _WorkflowGraphTopology | None = None,
) -> None:
if binding.agent_id is None:
raise WorkflowAgentNodeValidationError(f"Workflow Agent node {binding.node_id} is missing agent binding.")
if binding.current_snapshot_id is None:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} is missing config snapshot binding."
)
agent = session.scalar(
select(Agent)
.where(
Agent.tenant_id == binding.tenant_id,
Agent.id == binding.agent_id,
)
.limit(1)
)
if agent is None or agent.status == AgentStatus.ARCHIVED:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references an unavailable agent."
)
snapshot = session.scalar(
select(AgentConfigSnapshot)
.where(
AgentConfigSnapshot.tenant_id == binding.tenant_id,
AgentConfigSnapshot.agent_id == agent.id,
AgentConfigSnapshot.id == binding.current_snapshot_id,
)
.limit(1)
)
if snapshot is None:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references a missing config snapshot."
)
agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
if agent_soul.model is None:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} requires Agent Soul model config."
)
node_job = WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict)
cls.validate_node_job(session=session, binding=binding, node_job=node_job, topology=topology)
@classmethod
def validate_node_job(
cls,
*,
session: Session,
binding: WorkflowAgentNodeBinding,
node_job: WorkflowNodeJobConfig,
topology: _WorkflowGraphTopology | None = None,
) -> None:
cls._validate_locked_agent_soul_not_overridden(binding=binding, node_job=node_job)
output_names: set[str] = set()
for output in node_job.declared_outputs:
if output.name in output_names:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} has duplicate output name {output.name}."
)
output_names.add(output.name)
for check in output.checks:
if check.benchmark_file_ref is not None:
cls._validate_file_ref(
session=session,
binding=binding,
file_ref=check.benchmark_file_ref,
ref_context=f"output {output.name} benchmark file",
)
for ref in node_job.previous_node_output_refs:
selector = cls.selector_from_ref(ref)
if selector is None:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} has invalid previous node output ref."
)
if topology is None:
continue
if len(selector) < 2:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} has incomplete previous node output ref."
)
source_node_id = selector[0]
if not topology.has_node(source_node_id):
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references missing previous node {source_node_id}."
)
if not topology.is_upstream(source_node_id=source_node_id, target_node_id=binding.node_id):
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references non-upstream previous node {source_node_id}."
)
for human_ref in node_job.human_contacts:
cls._validate_human_ref(binding=binding, human_ref=human_ref)
file_refs = node_job.metadata.get("file_refs")
if isinstance(file_refs, list):
for file_ref in file_refs:
if isinstance(file_ref, Mapping):
cls._validate_file_ref(
session=session,
binding=binding,
file_ref=file_ref,
ref_context="metadata file ref",
)
@staticmethod
def iter_agent_v2_nodes(graph_dict: Mapping[str, Any]) -> Iterator[tuple[str, Mapping[str, Any]]]:
nodes = graph_dict.get("nodes")
if not isinstance(nodes, list):
return
for node in nodes:
if not isinstance(node, Mapping):
continue
node_id = node.get("id")
node_data = node.get("data")
if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
continue
if node_data.get("type") == BuiltinNodeTypes.AGENT and str(node_data.get("version")) == "2":
yield node_id, node_data
@staticmethod
def selector_from_ref(ref: Mapping[str, Any]) -> list[str] | None:
for key in ("selector", "variable_selector", "value_selector"):
value = ref.get(key)
if isinstance(value, list) and all(isinstance(item, str) for item in value):
return value
node_id = ref.get("node_id")
output_name = ref.get("output") or ref.get("name") or ref.get("variable") or ref.get("key")
if isinstance(node_id, str) and isinstance(output_name, str):
return [node_id, output_name]
return None
@staticmethod
def _validate_node_schema(*, node_id: str, node_data: Mapping[str, Any]) -> None:
try:
DifyAgentNodeData.model_validate(node_data)
except ValueError as exc:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {node_id} has invalid Agent v2 node schema: {exc}"
) from exc
@classmethod
def _validate_locked_agent_soul_not_overridden(
cls,
*,
binding: WorkflowAgentNodeBinding,
node_job: WorkflowNodeJobConfig,
) -> None:
forbidden_paths = cls._find_locked_agent_soul_paths(node_job.metadata)
if forbidden_paths:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} cannot override locked Agent Soul fields: "
f"{', '.join(sorted(forbidden_paths))}."
)
@classmethod
def _find_locked_agent_soul_paths(cls, value: Any, *, path: str = "metadata") -> set[str]:
if not isinstance(value, Mapping):
return set()
forbidden: set[str] = set()
for key, item in value.items():
key_text = str(key)
if key_text in cls._LOCKED_AGENT_SOUL_KEYS:
forbidden.add(f"{path}.{key_text}")
forbidden.update(cls._find_locked_agent_soul_paths(item, path=f"{path}.{key_text}"))
return forbidden
@classmethod
def _validate_human_ref(
cls,
*,
binding: WorkflowAgentNodeBinding,
human_ref: Mapping[str, Any],
) -> None:
contact_id = human_ref.get("contact_id") or human_ref.get("human_id") or human_ref.get("id")
if not isinstance(contact_id, str) or not contact_id:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} has invalid human contact ref."
)
tenant_id = human_ref.get("tenant_id")
if tenant_id is not None and tenant_id != binding.tenant_id:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references out-of-scope human contact {contact_id}."
)
channel = human_ref.get("channel") or human_ref.get("method") or human_ref.get("contact_method")
if channel is not None and channel not in cls._SUPPORTED_HUMAN_CONTACT_CHANNELS:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references unsupported human contact channel {channel}."
)
@staticmethod
def _validate_file_ref(
*,
session: Session,
binding: WorkflowAgentNodeBinding,
file_ref: Mapping[str, Any],
ref_context: str,
) -> None:
tenant_id = file_ref.get("tenant_id")
if tenant_id is not None and tenant_id != binding.tenant_id:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references out-of-scope {ref_context}."
)
upload_file_id = (
file_ref.get("upload_file_id") or file_ref.get("file_id") or file_ref.get("id") or file_ref.get("reference")
)
if upload_file_id is None and (file_ref.get("url") or file_ref.get("remote_url")):
return
if not isinstance(upload_file_id, str) or not upload_file_id:
raise WorkflowAgentNodeValidationError(f"Workflow Agent node {binding.node_id} has invalid {ref_context}.")
upload_file = session.scalar(
select(UploadFile)
.where(
UploadFile.tenant_id == binding.tenant_id,
UploadFile.id == upload_file_id,
)
.limit(1)
)
if upload_file is None:
raise WorkflowAgentNodeValidationError(
f"Workflow Agent node {binding.node_id} references missing or out-of-scope {ref_context}."
)
@staticmethod
def _find_binding(
*,
session: Session,
tenant_id: str,
app_id: str,
workflow_id: str,
node_id: str,
) -> WorkflowAgentNodeBinding | None:
return session.scalar(
select(WorkflowAgentNodeBinding)
.where(
WorkflowAgentNodeBinding.tenant_id == tenant_id,
WorkflowAgentNodeBinding.app_id == app_id,
WorkflowAgentNodeBinding.workflow_id == workflow_id,
WorkflowAgentNodeBinding.node_id == node_id,
)
.limit(1)
)
class _WorkflowGraphTopology:
def __init__(self, *, node_ids: set[str], incoming: Mapping[str, Sequence[str]]) -> None:
self._node_ids = node_ids
self._incoming = incoming
@classmethod
def from_graph(cls, graph: Mapping[str, Any]) -> _WorkflowGraphTopology:
node_ids = cls._node_ids_from_graph(graph)
incoming: dict[str, list[str]] = defaultdict(list)
edges = graph.get("edges")
if isinstance(edges, list):
for edge in edges:
if not isinstance(edge, Mapping):
continue
source = edge.get("source")
target = edge.get("target")
if isinstance(source, str) and isinstance(target, str):
incoming[target].append(source)
return cls(node_ids=node_ids, incoming=incoming)
def has_node(self, node_id: str) -> bool:
return node_id in self._node_ids
def is_upstream(self, *, source_node_id: str, target_node_id: str) -> bool:
if source_node_id == target_node_id:
return False
visited: set[str] = set()
queue: deque[str] = deque(self._incoming.get(target_node_id, ()))
while queue:
candidate = queue.popleft()
if candidate == source_node_id:
return True
if candidate in visited:
continue
visited.add(candidate)
queue.extend(self._incoming.get(candidate, ()))
return False
@staticmethod
def _node_ids_from_graph(graph: Mapping[str, Any]) -> set[str]:
node_ids: set[str] = set()
nodes = graph.get("nodes")
if not isinstance(nodes, list):
return node_ids
for node in nodes:
if not isinstance(node, Mapping):
continue
node_id = node.get("id")
if isinstance(node_id, str):
node_ids.add(node_id)
return node_ids

View File

@ -0,0 +1,664 @@
"""Lint Flask-RESTX response docs against statically visible response serializers.
This checker intentionally stays conservative. It only reports a hard schema
mismatch when both sides are statically known for the same 2xx status code:
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``
or ``Model.model_validate(...).model_dump()`` return.
Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing
response schemas, and returns with non-literal status codes are classified as
unknown so reviewers can triage them without blocking unrelated work. The one
intentional non-schema mismatch is a known body/schema on a no-body status such
as 204, 205, or 304.
"""
from __future__ import annotations
import argparse
import ast
import json
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import asdict, dataclass, field
from http import HTTPStatus
from pathlib import Path
from typing import Any, Literal
HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"}
NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value}
DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web")
type Classification = Literal["valid", "mismatch", "unknown", "refactorable"]
type ActualKind = Literal[
"empty",
"model",
"model_dump_variable",
"none",
"raw_dict",
"raw_list",
"raw_value",
"unknown",
]
type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef
HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus}
HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus})
@dataclass(frozen=True)
class DocumentedResponse:
status: int
model: str | None
line: int
@dataclass(frozen=True)
class ActualResponse:
status: int | None
kind: ActualKind
model: str | None
line: int
@dataclass(frozen=True)
class ContractCheck:
classification: Classification
file: str
class_name: str
method: str
route: str
line: int
reason: str
documented: dict[int, str | None]
actual: list[ActualResponse]
@dataclass(frozen=True)
class ContractCheckContext:
"""Stable route-method context shared by every classification result."""
file: str
class_name: str
method: str
route: str
line: int
documented: dict[int, str | None]
def build(
self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse]
) -> ContractCheck:
return ContractCheck(
classification=classification,
file=self.file,
class_name=self.class_name,
method=self.method,
route=self.route,
line=self.line,
reason=reason,
documented=self.documented,
actual=list(actual_responses),
)
def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck:
return self.build("mismatch", f"{reason} (doc line {documented.line}, return line {actual.line})", [actual])
@dataclass
class VariableAssignmentSummary:
"""Track whether a local name is safe to treat as one specific response model."""
known_models: set[str] = field(default_factory=set)
has_unknown_assignment: bool = False
def add_known(self, model: str) -> None:
self.known_models.add(model)
def add_unknown(self) -> None:
self.has_unknown_assignment = True
def single_known_model(self) -> str | None:
if self.has_unknown_assignment or len(self.known_models) != 1:
return None
return next(iter(self.known_models))
def dotted_name(node: ast.AST) -> str | None:
match node:
case ast.Name():
return node.id
case ast.Attribute():
parent = dotted_name(node.value)
if parent:
return f"{parent}.{node.attr}"
return node.attr
return None
def leaf_name(node: ast.AST) -> str | None:
name = dotted_name(node)
if name is None:
return None
return name.rsplit(".", 1)[-1]
def int_constant(node: ast.AST | None) -> int | None:
if isinstance(node, ast.Constant) and isinstance(node.value, int):
return node.value
if isinstance(node, ast.Name):
return HTTP_STATUS_NAMES.get(node.id)
if isinstance(node, ast.Attribute):
return HTTP_STATUS_NAMES.get(node.attr)
return None
def string_constant(node: ast.AST | None) -> str | None:
if isinstance(node, ast.Constant) and isinstance(node.value, str):
return node.value
return None
def keyword_value(call: ast.Call, *names: str) -> ast.AST | None:
for keyword in call.keywords:
if keyword.arg in names:
return keyword.value
return None
def is_probable_model_name(name: str) -> bool:
return bool(name) and name[0].isupper()
def model_name_from_schema_expr(node: ast.AST | None) -> str | None:
if node is None:
return None
if isinstance(node, ast.Subscript):
value_name = dotted_name(node.value)
if value_name and value_name.endswith(".models"):
# register_response_schema_models stores schemas by model name; both
# ns.models[Model.__name__] and ns.models["Model"] appear in controllers.
key = node.slice
if isinstance(key, ast.Attribute) and key.attr == "__name__":
return leaf_name(key.value)
return string_constant(key)
if isinstance(node, ast.Call):
func_name = dotted_name(node.func)
if func_name and func_name.endswith(".model"):
return string_constant(node.args[0] if node.args else keyword_value(node, "name"))
if isinstance(node, ast.Name):
return node.id
return None
def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None:
if not isinstance(decorator, ast.Call):
return None
func_name = dotted_name(decorator.func)
if not func_name or not func_name.endswith(".response"):
return None
status_expr = decorator.args[0] if decorator.args else keyword_value(decorator, "code", "status")
status = int_constant(status_expr)
if status is None:
return None
schema_expr: ast.AST | None = decorator.args[2] if len(decorator.args) >= 3 else None
schema_expr = keyword_value(decorator, "model", "schema") or schema_expr
return DocumentedResponse(
status=status,
model=model_name_from_schema_expr(schema_expr),
line=decorator.lineno,
)
def route_from_decorator(decorator: ast.expr) -> str | None:
if not isinstance(decorator, ast.Call):
return None
func_name = dotted_name(decorator.func)
if not func_name or not func_name.endswith(".route"):
return None
return string_constant(decorator.args[0] if decorator.args else keyword_value(decorator, "route", "path"))
def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]:
return [route for decorator in decorators if (route := route_from_decorator(decorator))]
def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]:
docs: dict[int, DocumentedResponse] = {}
for decorator in decorators:
response = documented_response_from_decorator(decorator)
if response and 200 <= response.status < 300:
docs[response.status] = response
return docs
def model_name_from_model_validate_call(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate":
return leaf_name(node.func.value)
return None
def model_name_from_constructor_call(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id):
return node.func.id
return None
def model_name_from_model_dump(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump":
return None
dumped_value = node.func.value
if isinstance(dumped_value, ast.Call):
return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value)
return None
def model_name_from_model_value(node: ast.AST) -> str | None:
return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node)
def model_name_from_dump_response(node: ast.AST) -> str | None:
if not isinstance(node, ast.Call):
return None
func_name = dotted_name(node.func)
if func_name != "dump_response" and not (func_name and func_name.endswith(".dump_response")):
return None
model_expr = node.args[0] if node.args else keyword_value(node, "model", "schema", "response_model")
if isinstance(model_expr, ast.Name):
return model_expr.id
return None
def actual_kind_from_expr(
expr: ast.AST | None, variable_models: dict[str, str] | None = None
) -> tuple[ActualKind, str | None]:
if expr is None:
return "none", None
dump_response_model = model_name_from_dump_response(expr)
if dump_response_model:
return "model", dump_response_model
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump":
dumped_value = expr.func.value
if isinstance(dumped_value, ast.Name) and variable_models:
# A variable dump can match today, but it bypasses dump_response and
# is easier to drift; keep it visible as refactorable.
model_name = variable_models.get(dumped_value.id)
if model_name:
return "model_dump_variable", model_name
model_dump_model = model_name_from_model_dump(expr)
if model_dump_model:
return "model", model_dump_model
if isinstance(expr, ast.Constant):
if expr.value is None:
return "none", None
if expr.value == "":
return "empty", None
return "raw_value", None
if isinstance(expr, ast.Dict):
return "raw_dict", None
if isinstance(expr, ast.List):
return "raw_list", None
return "unknown", None
def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse:
status: int | None = 200
body_expr = return_node.value
if isinstance(return_node.value, ast.Tuple) and return_node.value.elts:
body_expr = return_node.value.elts[0]
if len(return_node.value.elts) >= 2:
# Dynamic statuses are not safe to coerce to 200; classify them as unknown.
status = int_constant(return_node.value.elts[1])
kind, model = actual_kind_from_expr(body_expr, variable_models)
return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno)
def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
"""Yield method body nodes while ignoring nested function/class scopes."""
stack: list[ast.AST] = list(reversed(method.body))
while stack:
node = stack.pop()
yield node
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef):
continue
stack.extend(reversed(list(ast.iter_child_nodes(node))))
def target_names(target: ast.AST) -> Iterable[str]:
if isinstance(target, ast.Name):
yield target.id
elif isinstance(target, ast.Tuple | ast.List):
for item in target.elts:
yield from target_names(item)
def record_assignment(
assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None
) -> None:
for target in targets:
if model_name is None:
# Once a name receives an unknown value, later model_dump() calls on it
# are no longer a reliable signal for the returned schema.
assignments[target].add_unknown()
else:
assignments[target].add_known(model_name)
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
"""Infer local variables that are unambiguously assigned one response model."""
assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary)
for node in iter_method_nodes(method):
match node:
case ast.Assign(targets=targets, value=value):
record_assignment(
assignments,
(name for target in targets for name in target_names(target)),
model_name_from_model_value(value),
)
case ast.AnnAssign(target=target, value=value) if value is not None:
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target):
# Mutation and loop targets overwrite prior values with runtime-dependent data.
record_assignment(assignments, target_names(target), None)
case ast.With(items=items) | ast.AsyncWith(items=items):
for item in items:
if item.optional_vars is not None:
record_assignment(assignments, target_names(item.optional_vars), None)
case ast.ExceptHandler(name=name) if name:
assignments[name].add_unknown()
case ast.NamedExpr(target=target, value=value):
record_assignment(assignments, target_names(target), model_name_from_model_value(value))
return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None}
def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]:
"""Extract statically visible 2xx returns from one controller method.
The analysis is deliberately shallow and conservative:
1. Walk only the method's own body, skipping nested functions/classes.
2. Infer local variables that are assigned exactly one recognizable response
model, so ``response.model_dump()`` can still be connected to its schema.
3. Treat any later unknown assignment, mutation target, loop target, context
manager binding, or exception binding as invalidating that variable.
4. For each top-level return path, split Flask-style ``(body, status)``
tuples, classify the body expression, and keep non-literal statuses as
``None`` so the classifier reports them as unknown instead of assuming 200.
5. Drop non-2xx literal statuses, since response contracts here only compare
successful response schemas.
"""
variable_models = variable_model_assignments_for_method(method)
responses: list[ActualResponse] = []
for node in iter_method_nodes(method):
if isinstance(node, ast.Return):
responses.append(actual_response_from_return(node, variable_models))
return [response for response in responses if response.status is None or 200 <= response.status < 300]
def display_path(file_path: Path, repo_root: Path) -> str:
try:
return str(file_path.relative_to(repo_root))
except ValueError:
return str(file_path)
def classify_method(
*,
actual_responses: Sequence[ActualResponse],
class_name: str,
documented_responses: dict[int, DocumentedResponse],
file_path: Path,
method: MethodNode,
repo_root: Path,
route: str,
) -> ContractCheck:
documented_summary = {status: response.model for status, response in sorted(documented_responses.items())}
context = ContractCheckContext(
file=display_path(file_path, repo_root),
class_name=class_name,
method=method.name,
route=route,
line=method.lineno,
documented=documented_summary,
)
if not actual_responses:
return context.build("unknown", "no statically visible 2xx return", [])
unknown_reasons: list[str] = []
refactorable_reasons: list[str] = []
for actual in actual_responses:
if actual.status is None:
unknown_reasons.append(f"return line {actual.line} has non-literal or unsupported status")
continue
documented = documented_responses.get(actual.status)
if actual.status in NO_BODY_STATUSES:
# No-body statuses are contract violations even when the schema names
# would otherwise match, because clients should not expect a payload.
if documented is not None and documented.model is not None:
return context.mismatch(
f"status {actual.status} is a no-body response but documents {documented.model}",
documented,
actual,
)
if actual.kind in {"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"}:
no_body_doc = DocumentedResponse(status=actual.status, model=None, line=method.lineno)
return context.mismatch(
f"status {actual.status} is a no-body response but returns {actual.kind}",
no_body_doc,
actual,
)
if actual.kind == "unknown":
unknown_reasons.append(f"status {actual.status} returns unknown body expression")
continue
if documented is None:
unknown_reasons.append(f"status {actual.status} has no @response doc")
continue
if documented.model is None:
unknown_reasons.append(f"status {actual.status} response doc has no schema model")
continue
if actual.kind == "model_dump_variable" and actual.model is not None:
if documented.model != actual.model:
return context.mismatch(
f"status {actual.status} documents {documented.model} but returns {actual.model}",
documented,
actual,
)
# The schema matches, but this path still deserves cleanup because
# dump_response is the contract-aware serialization helper.
refactorable_reasons.append(
f"status {actual.status} returns {actual.model}.model_dump() through a variable; prefer dump_response"
)
continue
if actual.kind != "model" or actual.model is None:
unknown_reasons.append(f"status {actual.status} returns {actual.kind}")
continue
if documented.model != actual.model:
return context.mismatch(
f"status {actual.status} documents {documented.model} but returns {actual.model}",
documented,
actual,
)
if unknown_reasons:
# Unknown beats refactorable: if any return path is ambiguous, do not
# imply the endpoint is merely a cleanup candidate.
return context.build("unknown", "; ".join(sorted(set(unknown_reasons))), actual_responses)
if refactorable_reasons:
return context.build("refactorable", "; ".join(sorted(set(refactorable_reasons))), actual_responses)
return context.build(
"valid",
"documented response schema matches statically visible return schema",
actual_responses,
)
def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]:
for path in paths:
if path.is_file() and path.suffix == ".py":
yield path
elif path.is_dir():
yield from sorted(child for child in path.rglob("*.py") if child.is_file())
def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
checks: list[ContractCheck] = []
for node in module.body:
if not isinstance(node, ast.ClassDef):
continue
class_routes = routes_from_decorators(node.decorator_list)
class_documented = response_docs_from_decorators(node.decorator_list)
for item in node.body:
if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS:
continue
routes = routes_from_decorators(item.decorator_list) or class_routes
if not routes:
continue
documented = {**class_documented, **response_docs_from_decorators(item.decorator_list)}
# Method-level @response decorators override class-level defaults for
# the same status code, matching Flask-RESTX's common controller style.
actual = actual_responses_for_method(item)
for route in routes:
checks.append(
classify_method(
actual_responses=actual,
class_name=node.name,
documented_responses=documented,
file_path=file_path,
method=item,
repo_root=repo_root,
route=route,
)
)
return checks
def as_jsonable(check: ContractCheck) -> dict[str, Any]:
data = asdict(check)
data["documented"] = {str(status): model for status, model in check.documented.items()}
return data
def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None:
counts = Counter(check.classification for check in checks)
sys.stdout.write(
"Response contract lint: "
f"{counts['valid']} valid, "
f"{counts['mismatch']} mismatch, "
f"{counts['refactorable']} refactorable, "
f"{counts['unknown']} unknown\n"
)
for classification in ("mismatch", "refactorable", "unknown", "valid"):
filtered = [check for check in checks if check.classification == classification]
if classification == "valid" and not include_valid:
continue
if not filtered:
continue
sys.stdout.write(f"\n{classification.upper()}:\n")
for check in filtered:
location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}"
sys.stdout.write(f"- {location}: {check.reason}\n")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"paths",
nargs="*",
help="Files or directories to lint. Defaults to Flask controller directories.",
)
parser.add_argument("--include-valid", action="store_true", help="Print valid route methods in text output.")
parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON.")
parser.add_argument(
"--fail-on-mismatch",
action="store_true",
help="Treat mismatched response contracts as failures. By default this linter is report-only.",
)
parser.add_argument(
"--fail-on-unknown",
action="store_true",
help="Treat unknown route methods as failures. By default this linter is report-only.",
)
return parser.parse_args()
def main() -> int:
args = parse_args()
api_root = Path(__file__).resolve().parents[1]
repo_root = api_root.parent
raw_paths = args.paths or list(DEFAULT_CONTROLLER_DIRS)
paths = [path if path.is_absolute() else api_root / path for path in map(Path, raw_paths)]
checks: list[ContractCheck] = []
for file_path in iter_controller_files(paths):
checks.extend(checks_for_file(file_path.resolve(), repo_root))
checks.sort(key=lambda check: (check.classification, check.file, check.line, check.method))
if args.json:
grouped = defaultdict(list)
for check in checks:
grouped[check.classification].append(as_jsonable(check))
sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n")
else:
print_text_report(checks, include_valid=bool(args.include_valid))
has_mismatch = any(check.classification == "mismatch" for check in checks)
has_unknown = any(check.classification == "unknown" for check in checks)
return int((bool(args.fail_on_mismatch) and has_mismatch) or (bool(args.fail_on_unknown) and has_unknown))
if __name__ == "__main__":
raise SystemExit(main())

View File

@ -457,14 +457,16 @@ def init_app(app: DifyApp):
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
return StreamsBroadcastChannel(
_pubsub_redis_client,
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
join_timeout_ms=join_timeout_ms,
)
return RedisBroadcastChannel(_pubsub_redis_client)
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
def redis_fallback[T](default_return: T | None = None): # type: ignore

View File

@ -53,25 +53,27 @@ def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
def _convert_value(value: Any) -> Any:
"""Recursively convert non-serializable values."""
if value is None:
return None
if isinstance(value, (bool, int, float, str)):
return value
if isinstance(value, Segment):
# Convert Segment to its underlying value
return _convert_value(value.value)
if isinstance(value, File):
# Convert File to dict
return value.to_dict()
if isinstance(value, BaseModel):
# Convert Pydantic model to dict
return _convert_value(value.model_dump(mode="json"))
if isinstance(value, dict):
return {k: _convert_value(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_convert_value(item) for item in value]
# Fallback to string representation for unknown types
return str(value)
match value:
case None:
return None
case bool() | int() | float() | str():
return value
case Segment():
# Convert Segment to its underlying value
return _convert_value(value.value)
case File():
# Convert File to dict
return value.to_dict()
case BaseModel():
# Convert Pydantic model to dict
return _convert_value(value.model_dump(mode="json"))
case dict():
return {k: _convert_value(v) for k, v in value.items()}
case list() | tuple():
return [_convert_value(item) for item in value]
case _:
# Fallback to string representation for unknown types
return str(value)
try:
converted = _convert_value(obj)
@ -104,15 +106,15 @@ class DefaultNodeOTelParser:
span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
node_type = node.node_type
if node_type == BuiltinNodeTypes.LLM:
span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
elif node_type == BuiltinNodeTypes.TOOL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
else:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
match node.node_type:
case BuiltinNodeTypes.LLM:
span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
case BuiltinNodeTypes.TOOL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
case _:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:

Some files were not shown because too many files have changed in this diff Show More