mirror of
https://github.com/langgenius/dify.git
synced 2026-05-23 18:38:26 +08:00
Compare commits
41 Commits
dependabot
...
codex/dify
| Author | SHA1 | Date | |
|---|---|---|---|
| 40d06bd476 | |||
| 13e024618a | |||
| 5adf995527 | |||
| 5d4def8298 | |||
| d5d0d2d96f | |||
| 2a0c098857 | |||
| 790ca72627 | |||
| 4d8b6c7dc0 | |||
| 473c945839 | |||
| a698c60b29 | |||
| 24bab5fb2a | |||
| 93b7a81071 | |||
| 157e6244dd | |||
| 964aaad7ed | |||
| 92181dbe09 | |||
| 30deef45d9 | |||
| ee28074390 | |||
| 1fb491337b | |||
| 82b0a03f5a | |||
| 6185016910 | |||
| b4f5f4869f | |||
| 7ecbed3b04 | |||
| 5b58defd62 | |||
| 73196de5e1 | |||
| ea5e487d3c | |||
| f19702f76c | |||
| 092c8bca81 | |||
| c50d504c44 | |||
| 1b4356b66a | |||
| 7f633622aa | |||
| 66f5ab4cfc | |||
| 0cf9597f52 | |||
| 60cd346fa6 | |||
| 56d4d54c16 | |||
| 9f9cb4d17e | |||
| 7d0d9019d8 | |||
| d646bcf257 | |||
| e3b45a48eb | |||
| 848c15a265 | |||
| be8627233d | |||
| 1fe8b7fb1d |
@ -63,8 +63,8 @@ jobs:
|
||||
id: render
|
||||
run: |
|
||||
comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \
|
||||
--base base_report.json \
|
||||
< pr_report.json)"
|
||||
--base "$GITHUB_WORKSPACE/base_report.json" \
|
||||
< "$GITHUB_WORKSPACE/pr_report.json")"
|
||||
|
||||
{
|
||||
echo "### Pyrefly Type Coverage"
|
||||
|
||||
4
.github/workflows/pyrefly-type-coverage.yml
vendored
4
.github/workflows/pyrefly-type-coverage.yml
vendored
@ -65,6 +65,9 @@ jobs:
|
||||
# Save structured data for the fork-PR comment workflow
|
||||
cp /tmp/pyrefly_report_pr.json pr_report.json
|
||||
cp /tmp/pyrefly_report_base.json base_report.json
|
||||
# Keep fork-PR comments correct while the trusted workflow_run job is
|
||||
# still using the default-branch renderer, which resolves --base from api/.
|
||||
cp /tmp/pyrefly_report_base.json api/base_report.json
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
@ -77,6 +80,7 @@ jobs:
|
||||
path: |
|
||||
pr_report.json
|
||||
base_report.json
|
||||
api/base_report.json
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with type coverage
|
||||
|
||||
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -47,6 +47,10 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --directory api --dev lint-imports
|
||||
|
||||
- name: Run Response Contract Linter
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch
|
||||
|
||||
- name: Run Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: make type-check-core
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
run: vp test run --reporter=blob --reporter=minimal --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
|
||||
11
Makefile
11
Makefile
@ -75,13 +75,19 @@ check:
|
||||
@echo "✅ Code check complete"
|
||||
|
||||
lint:
|
||||
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
|
||||
@echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..."
|
||||
@uv run --project api --dev ruff format ./api
|
||||
@uv run --project api --dev ruff check --fix ./api
|
||||
@$(MAKE) api-contract-lint
|
||||
@uv run --directory api --dev lint-imports
|
||||
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
api-contract-lint:
|
||||
@echo "🔎 Linting Flask response contracts..."
|
||||
@uv run --project api --dev python api/dev/lint_response_contracts.py
|
||||
@echo "✅ Response contract lint complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (pyrefly + mypy)..."
|
||||
@./dev/pyrefly-check-local $(PATH_TO_CHECK)
|
||||
@ -191,6 +197,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make api-contract-lint - Check Flask response docs against returned schemas"
|
||||
@echo " make type-check - Run type checks (pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@ -204,4 +211,4 @@ help:
|
||||
@echo " make build-push-all - Build and push all Docker images"
|
||||
|
||||
# Phony targets
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all
|
||||
.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all
|
||||
|
||||
@ -767,6 +767,7 @@ EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode when Redis is used as the event bus.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
EVENT_BUS_REDIS_USE_CLUSTERS=false
|
||||
EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS=2000
|
||||
|
||||
# Whether to enable the human input timeout check task.
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
|
||||
@ -195,6 +195,7 @@ Before opening a PR / submitting:
|
||||
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
|
||||
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
|
||||
- Document non-obvious behaviour with concise docstrings and comments.
|
||||
- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload.
|
||||
- For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`.
|
||||
In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response
|
||||
DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`,
|
||||
|
||||
@ -49,6 +49,7 @@ class AgentBackendModelConfig(BaseModel):
|
||||
model: str
|
||||
user_id: str | None = None
|
||||
credentials: dict[str, DifyPluginCredentialValue] = Field(default_factory=dict)
|
||||
model_settings: dict[str, JsonValue] = Field(default_factory=dict)
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
|
||||
|
||||
@ -138,6 +139,7 @@ class AgentBackendRunRequestBuilder:
|
||||
model_provider=run_input.model.model_provider,
|
||||
model=run_input.model.model,
|
||||
credentials=run_input.model.credentials,
|
||||
model_settings=run_input.model.model_settings or None,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from configs.extra.agent_backend_config import AgentBackendConfig
|
||||
from configs.extra.archive_config import ArchiveStorageConfig
|
||||
from configs.extra.notion_config import NotionConfig
|
||||
from configs.extra.sentry_config import SentryConfig
|
||||
@ -5,6 +6,7 @@ from configs.extra.sentry_config import SentryConfig
|
||||
|
||||
class ExtraServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
AgentBackendConfig,
|
||||
ArchiveStorageConfig,
|
||||
NotionConfig,
|
||||
SentryConfig,
|
||||
|
||||
23
api/configs/extra/agent_backend_config.py
Normal file
23
api/configs/extra/agent_backend_config.py
Normal file
@ -0,0 +1,23 @@
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class AgentBackendConfig(BaseSettings):
    """
    Settings that control how the API talks to the Agent backend runtime.

    Three values are exposed: the base URL of the backend service, a switch
    that selects the in-process fake client, and the scenario that fake
    client should use.
    """

    # No URL is configured by default.
    AGENT_BACKEND_BASE_URL: str | None = Field(
        default=None,
        description="Base URL for the Dify Agent backend service.",
    )

    # The fake client is opt-in.
    AGENT_BACKEND_USE_FAKE: bool = Field(
        default=False,
        description="Use the deterministic in-process fake Agent backend client.",
    )

    # Scenario name handed to the fake client; "success" unless overridden.
    AGENT_BACKEND_FAKE_SCENARIO: str = Field(
        default="success",
        description="Scenario used by the fake Agent backend client.",
    )
|
||||
@ -2,6 +2,7 @@ from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic.types import NonNegativeInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@ -70,6 +71,24 @@ class RedisPubSubConfig(BaseSettings):
|
||||
default=600,
|
||||
)
|
||||
|
||||
PUBSUB_LISTENER_JOIN_TIMEOUT_MS: NonNegativeInt = Field(
|
||||
validation_alias=AliasChoices("EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS", "PUBSUB_LISTENER_JOIN_TIMEOUT_MS"),
|
||||
description=(
|
||||
"Maximum time (milliseconds) that ``Subscription.close()`` waits for its listener thread to "
|
||||
"finish before returning. Bounds the tail latency between a terminal event being delivered to "
|
||||
"an SSE client and the response stream actually closing.\n\n"
|
||||
"The listener thread blocks on a polling read (XREAD BLOCK for streams, get_message timeout "
|
||||
"for pubsub/sharded) with a fixed 1s window, so close() naturally has to wait up to ~1s for "
|
||||
"the thread to notice the subscription was closed. Setting this lower (e.g. 100) lets close() "
|
||||
"return promptly while the daemon listener thread cleans itself up on the next poll "
|
||||
"boundary - safe because the listener holds no critical state and exits within one poll "
|
||||
"window. Setting it higher (e.g. 5000) gives the listener more grace before close() gives up "
|
||||
"and logs a warning. Default 2000ms preserves the pre-change behaviour.\n\n"
|
||||
"Also accepts ENV: EVENT_BUS_LISTENER_JOIN_TIMEOUT_MS."
|
||||
),
|
||||
default=2000,
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
|
||||
@ -36,6 +36,24 @@ class FileInfo(BaseModel):
|
||||
size: int
|
||||
|
||||
|
||||
def decode_remote_url(url: str, query_string: bytes | str = b"") -> str:
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
if isinstance(query_string, bytes):
|
||||
raw_query = query_string.decode()
|
||||
else:
|
||||
raw_query = query_string
|
||||
if not raw_query:
|
||||
return decoded_url
|
||||
|
||||
if decoded_url.endswith(("?", "&")):
|
||||
separator = ""
|
||||
elif urllib.parse.urlsplit(decoded_url).query:
|
||||
separator = "&"
|
||||
else:
|
||||
separator = "?"
|
||||
return f"{decoded_url}{separator}{raw_query}"
|
||||
|
||||
|
||||
def guess_file_info_from_response(response: httpx.Response):
|
||||
url = str(response.url)
|
||||
# Try to extract filename from URL
|
||||
|
||||
@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource):
|
||||
db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id))
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:resource_id>/api-keys")
|
||||
|
||||
@ -269,12 +269,12 @@ class AnnotationApi(Resource):
|
||||
"message": "annotation_ids are required if the parameter is provided.",
|
||||
}, 400
|
||||
|
||||
result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return result, 204
|
||||
AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids)
|
||||
return "", 204
|
||||
# If no annotation_ids are provided, handle clearing all annotations
|
||||
else:
|
||||
AppAnnotationService.clear_all_annotations(str(app_id))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/export")
|
||||
@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@edit_permission_required
|
||||
def delete(self, app_id: UUID, annotation_id: UUID):
|
||||
AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/annotations/batch-import")
|
||||
|
||||
@ -633,7 +633,7 @@ class AppApi(Resource):
|
||||
app_service = AppService()
|
||||
app_service.delete_app(app_model)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/copy")
|
||||
|
||||
@ -29,9 +29,6 @@ from fields.conversation_fields import (
|
||||
from fields.conversation_fields import (
|
||||
ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse,
|
||||
)
|
||||
from fields.conversation_fields import (
|
||||
ResultResponse,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now, parse_time_range
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
@ -77,7 +74,6 @@ register_schema_models(
|
||||
ConversationMessageDetailResponse,
|
||||
ConversationWithSummaryPaginationResponse,
|
||||
ConversationDetailResponse,
|
||||
ResultResponse,
|
||||
CompletionConversationQuery,
|
||||
ChatConversationQuery,
|
||||
)
|
||||
@ -194,7 +190,7 @@ class CompletionConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/chat-conversations")
|
||||
@ -347,7 +343,7 @@ class ChatConversationDetailApi(Resource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
def _get_conversation(app_model, conversation_id):
|
||||
|
||||
@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource):
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
except Exception as e:
|
||||
raise BadRequest(str(e))
|
||||
|
||||
@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource):
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
|
||||
@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
from typing import Any, cast
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.fields import ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse
|
||||
from controllers.common.schema import get_or_create_model, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.apikey import ApiKeyItem, ApiKeyList
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
@ -30,26 +31,10 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_detail_kernel_fields, related_app_list
|
||||
from fields.dataset_fields import (
|
||||
content_fields,
|
||||
dataset_detail_fields,
|
||||
dataset_fields,
|
||||
dataset_query_detail_fields,
|
||||
dataset_retrieval_model_fields,
|
||||
doc_metadata_fields,
|
||||
external_knowledge_info_fields,
|
||||
external_retrieval_model_fields,
|
||||
file_info_fields,
|
||||
icon_info_fields,
|
||||
keyword_setting_fields,
|
||||
reranking_model_fields,
|
||||
tag_fields,
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from fields.document_fields import document_status_fields
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import build_icon_url, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.url_utils import normalize_api_base_url
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
@ -61,58 +46,6 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D
|
||||
|
||||
register_response_schema_models(console_ns, ApiBaseUrlResponse, SimpleResultResponse, UsageCheckResponse)
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
dataset_base_model = get_or_create_model("DatasetBase", dataset_fields)
|
||||
|
||||
tag_model = get_or_create_model("Tag", tag_fields)
|
||||
|
||||
keyword_setting_model = get_or_create_model("DatasetKeywordSetting", keyword_setting_fields)
|
||||
vector_setting_model = get_or_create_model("DatasetVectorSetting", vector_setting_fields)
|
||||
|
||||
weighted_score_fields_copy = weighted_score_fields.copy()
|
||||
weighted_score_fields_copy["keyword_setting"] = fields.Nested(keyword_setting_model)
|
||||
weighted_score_fields_copy["vector_setting"] = fields.Nested(vector_setting_model)
|
||||
weighted_score_model = get_or_create_model("DatasetWeightedScore", weighted_score_fields_copy)
|
||||
|
||||
reranking_model = get_or_create_model("DatasetRerankingModel", reranking_model_fields)
|
||||
|
||||
dataset_retrieval_model_fields_copy = dataset_retrieval_model_fields.copy()
|
||||
dataset_retrieval_model_fields_copy["reranking_model"] = fields.Nested(reranking_model)
|
||||
dataset_retrieval_model_fields_copy["weights"] = fields.Nested(weighted_score_model, allow_null=True)
|
||||
dataset_retrieval_model = get_or_create_model("DatasetRetrievalModel", dataset_retrieval_model_fields_copy)
|
||||
|
||||
external_knowledge_info_model = get_or_create_model("ExternalKnowledgeInfo", external_knowledge_info_fields)
|
||||
|
||||
external_retrieval_model = get_or_create_model("ExternalRetrievalModel", external_retrieval_model_fields)
|
||||
|
||||
doc_metadata_model = get_or_create_model("DatasetDocMetadata", doc_metadata_fields)
|
||||
|
||||
icon_info_model = get_or_create_model("DatasetIconInfo", icon_info_fields)
|
||||
|
||||
dataset_detail_fields_copy = dataset_detail_fields.copy()
|
||||
dataset_detail_fields_copy["retrieval_model_dict"] = fields.Nested(dataset_retrieval_model)
|
||||
dataset_detail_fields_copy["tags"] = fields.List(fields.Nested(tag_model))
|
||||
dataset_detail_fields_copy["external_knowledge_info"] = fields.Nested(external_knowledge_info_model)
|
||||
dataset_detail_fields_copy["external_retrieval_model"] = fields.Nested(external_retrieval_model, allow_null=True)
|
||||
dataset_detail_fields_copy["doc_metadata"] = fields.List(fields.Nested(doc_metadata_model))
|
||||
dataset_detail_fields_copy["icon_info"] = fields.Nested(icon_info_model)
|
||||
dataset_detail_model = get_or_create_model("DatasetDetail", dataset_detail_fields_copy)
|
||||
|
||||
file_info_model = get_or_create_model("DatasetFileInfo", file_info_fields)
|
||||
|
||||
content_fields_copy = content_fields.copy()
|
||||
content_fields_copy["file_info"] = fields.Nested(file_info_model, allow_null=True)
|
||||
content_model = get_or_create_model("DatasetContent", content_fields_copy)
|
||||
|
||||
dataset_query_detail_fields_copy = dataset_query_detail_fields.copy()
|
||||
dataset_query_detail_fields_copy["queries"] = fields.Nested(content_model)
|
||||
dataset_query_detail_model = get_or_create_model("DatasetQueryDetail", dataset_query_detail_fields_copy)
|
||||
|
||||
app_detail_kernel_model = get_or_create_model("AppDetailKernel", app_detail_kernel_fields)
|
||||
related_app_list_copy = related_app_list.copy()
|
||||
related_app_list_copy["data"] = fields.List(fields.Nested(app_detail_kernel_model))
|
||||
related_app_list_model = get_or_create_model("RelatedAppList", related_app_list_copy)
|
||||
|
||||
|
||||
def _validate_indexing_technique(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
@ -208,9 +141,165 @@ class ConsoleDatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
# One row of the dataset list: every DatasetDetailResponse field plus a
# mandatory partial-member list.
class DatasetListItemResponse(DatasetDetailResponse):
    partial_member_list: list[str]
|
||||
|
||||
|
||||
# Paginated envelope returned by the dataset list endpoint.
class DatasetListResponse(ResponseModel):
    # Items on the current page.
    data: list[DatasetListItemResponse]
    # Pagination metadata.
    has_more: bool
    limit: int
    total: int
    page: int
|
||||
|
||||
|
||||
# Dataset detail payload whose partial-member list is optional (None when absent).
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
    partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
# File metadata attached to a dataset query content entry.
class DatasetQueryFileInfoResponse(ResponseModel):
    id: str
    name: str
    size: int
    extension: str
    mime_type: str
    source_url: str
|
||||
|
||||
|
||||
# A single content entry of a dataset query; file_info is optional.
class DatasetQueryContentResponse(ResponseModel):
    content_type: str
    content: str
    file_info: DatasetQueryFileInfoResponse | None = None
|
||||
|
||||
|
||||
# One dataset query history record.
class DatasetQueryDetailResponse(ResponseModel):
    id: str
    queries: list[DatasetQueryContentResponse]
    source: str
    source_app_id: str | None
    created_by_role: str
    created_by: str
    created_at: int

    @field_validator("created_at", mode="before")
    @classmethod
    def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
        """Pass the raw ``created_at`` value through ``to_timestamp`` before validation."""
        # NOTE(review): presumably turns a datetime into an integer timestamp - confirm in libs.helper.
        normalized = to_timestamp(value)
        return normalized
|
||||
|
||||
|
||||
# Paginated envelope for dataset query history.
class DatasetQueryListResponse(ResponseModel):
    # Records on the current page.
    data: list[DatasetQueryDetailResponse]
    # Pagination metadata.
    has_more: bool
    limit: int
    total: int
    page: int
|
||||
|
||||
|
||||
# Summary of an app, as listed by RelatedAppListResponse.
class RelatedAppResponse(ResponseModel):
    id: str
    name: str
    description: str
    # Populated from the source attribute "mode_compatible_with_agent".
    mode: str = Field(validation_alias="mode_compatible_with_agent")
    icon_type: str | None
    icon: str | None
    icon_background: str | None
    icon_url: str | None = None

    @model_validator(mode="after")
    def _set_icon_url(self) -> "RelatedAppResponse":
        """Fill ``icon_url`` from ``icon_type`` and ``icon`` when it is empty."""
        resolved = self.icon_url or build_icon_url(self.icon_type, self.icon)
        self.icon_url = resolved
        return self
|
||||
|
||||
|
||||
# List of related apps together with a total count.
class RelatedAppListResponse(ResponseModel):
    data: list[RelatedAppResponse]
    total: int
|
||||
|
||||
|
||||
# Indexing status of a single document, including per-stage timestamps.
class DocumentStatusResponse(ResponseModel):
    id: str
    indexing_status: str
    # Stage timestamps; each may be None.
    processing_started_at: int | None
    parsing_completed_at: int | None
    cleaning_completed_at: int | None
    splitting_completed_at: int | None
    completed_at: int | None
    paused_at: int | None
    error: str | None
    stopped_at: int | None
    # Segment counters are optional and default to None.
    completed_segments: int | None = None
    total_segments: int | None = None

    @field_validator(
        "processing_started_at",
        "parsing_completed_at",
        "cleaning_completed_at",
        "splitting_completed_at",
        "completed_at",
        "paused_at",
        "stopped_at",
        mode="before",
    )
    @classmethod
    def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
        """Pass each raw timestamp field through ``to_timestamp`` before validation."""
        # NOTE(review): presumably turns a datetime into an integer timestamp - confirm in libs.helper.
        normalized = to_timestamp(value)
        return normalized
|
||||
|
||||
|
||||
# Envelope holding a list of document status entries.
class DocumentStatusListResponse(ResponseModel):
    data: list[DocumentStatusResponse]
|
||||
|
||||
|
||||
# Document status list extended with a total count.
class ErrorDocsResponse(DocumentStatusListResponse):
    total: int
|
||||
|
||||
|
||||
# One preview item of an indexing estimate; child chunks and summary are optional.
class IndexingEstimatePreviewItemResponse(ResponseModel):
    content: str
    child_chunks: list[str] | None = None
    summary: str | None = None
|
||||
|
||||
|
||||
# One question/answer pair of an indexing estimate's QA preview.
class IndexingEstimateQaPreviewItemResponse(ResponseModel):
    question: str
    answer: str
|
||||
|
||||
|
||||
# Indexing estimate: segment count, preview items and an optional QA preview.
class IndexingEstimateResponse(ResponseModel):
    total_segments: int
    preview: list[IndexingEstimatePreviewItemResponse]
    qa_preview: list[IndexingEstimateQaPreviewItemResponse] | None = None
|
||||
|
||||
|
||||
# Response carrying a list of retrieval method names.
class RetrievalSettingResponse(ResponseModel):
    retrieval_method: list[str]
|
||||
|
||||
|
||||
# Envelope holding a partial-member list as plain strings.
class PartialMemberListResponse(ResponseModel):
    data: list[str]
|
||||
|
||||
|
||||
# Auto-disable log summary: a list of document IDs and a count.
class AutoDisableLogsResponse(ResponseModel):
    document_ids: list[str]
    count: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
|
||||
)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetQueryListResponse,
|
||||
IndexingEstimateResponse,
|
||||
RelatedAppListResponse,
|
||||
DocumentStatusListResponse,
|
||||
ErrorDocsResponse,
|
||||
RetrievalSettingResponse,
|
||||
PartialMemberListResponse,
|
||||
AutoDisableLogsResponse,
|
||||
)
|
||||
|
||||
|
||||
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
|
||||
@ -293,17 +382,8 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
|
||||
class DatasetListApi(Resource):
|
||||
@console_ns.doc("get_datasets")
|
||||
@console_ns.doc(description="Get list of datasets")
|
||||
@console_ns.doc(
|
||||
params={
|
||||
"page": "Page number (default: 1)",
|
||||
"limit": "Number of items per page (default: 20)",
|
||||
"ids": "Filter by dataset IDs (list)",
|
||||
"keyword": "Search keyword",
|
||||
"tag_ids": "Filter by tag IDs (list)",
|
||||
"include_all": "Include all datasets (default: false)",
|
||||
}
|
||||
)
|
||||
@console_ns.response(200, "Datasets retrieved successfully")
|
||||
@console_ns.doc(params=query_params_from_model(ConsoleDatasetListQuery))
|
||||
@console_ns.response(200, "Datasets retrieved successfully", console_ns.models[DatasetListResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -342,7 +422,7 @@ class DatasetListApi(Resource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"]
|
||||
partial_members_map: dict[str, list[str]] = {}
|
||||
if dataset_ids:
|
||||
@ -379,12 +459,12 @@ class DatasetListApi(Resource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
|
||||
@console_ns.doc("create_dataset")
|
||||
@console_ns.doc(description="Create a new dataset")
|
||||
@console_ns.expect(console_ns.models[DatasetCreatePayload.__name__])
|
||||
@console_ns.response(201, "Dataset created successfully")
|
||||
@console_ns.response(201, "Dataset created successfully", console_ns.models[DatasetDetailResponse.__name__])
|
||||
@console_ns.response(400, "Invalid request parameters")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -413,7 +493,7 @@ class DatasetListApi(Resource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
return dump_response(DatasetDetailResponse, dataset), 201
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -421,7 +501,11 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("get_dataset")
|
||||
@console_ns.doc(description="Get dataset details")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Dataset retrieved successfully", dataset_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -437,7 +521,7 @@ class DatasetApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
if dataset.embedding_model_provider:
|
||||
provider_id = ModelProviderID(dataset.embedding_model_provider)
|
||||
@ -470,7 +554,11 @@ class DatasetApi(Resource):
|
||||
@console_ns.doc("update_dataset")
|
||||
@console_ns.doc(description="Update dataset details")
|
||||
@console_ns.expect(console_ns.models[DatasetUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Dataset updated successfully", dataset_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
console_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -506,7 +594,7 @@ class DatasetApi(Resource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
@ -535,7 +623,7 @@ class DatasetApi(Resource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@ -567,7 +655,11 @@ class DatasetQueryApi(Resource):
|
||||
@console_ns.doc("get_dataset_queries")
|
||||
@console_ns.doc(description="Get dataset query history")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Query history retrieved successfully", dataset_query_detail_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Query history retrieved successfully",
|
||||
console_ns.models[DatasetQueryListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -589,20 +681,24 @@ class DatasetQueryApi(Resource):
|
||||
dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)
|
||||
|
||||
response = {
|
||||
"data": marshal(dataset_queries, dataset_query_detail_model),
|
||||
"data": dataset_queries,
|
||||
"has_more": len(dataset_queries) == limit,
|
||||
"limit": limit,
|
||||
"total": total,
|
||||
"page": page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetQueryListResponse, response), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/indexing-estimate")
|
||||
class DatasetIndexingEstimateApi(Resource):
|
||||
@console_ns.doc("estimate_dataset_indexing")
|
||||
@console_ns.doc(description="Estimate dataset indexing cost")
|
||||
@console_ns.response(200, "Indexing estimate calculated successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing estimate calculated successfully",
|
||||
console_ns.models[IndexingEstimateResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -699,11 +795,14 @@ class DatasetRelatedAppListApi(Resource):
|
||||
@console_ns.doc("get_dataset_related_apps")
|
||||
@console_ns.doc(description="Get applications related to dataset")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Related apps retrieved successfully", related_app_list_model)
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Related apps retrieved successfully",
|
||||
console_ns.models[RelatedAppListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(related_app_list_model)
|
||||
def get(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -724,7 +823,7 @@ class DatasetRelatedAppListApi(Resource):
|
||||
if app_model:
|
||||
related_apps.append(app_model)
|
||||
|
||||
return {"data": related_apps, "total": len(related_apps)}, 200
|
||||
return dump_response(RelatedAppListResponse, {"data": related_apps, "total": len(related_apps)}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/indexing-status")
|
||||
@ -732,7 +831,11 @@ class DatasetIndexingStatusApi(Resource):
|
||||
@console_ns.doc("get_dataset_indexing_status")
|
||||
@console_ns.doc(description="Get dataset indexing status")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Indexing status retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Indexing status retrieved successfully",
|
||||
console_ns.models[DocumentStatusListResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@ -778,9 +881,8 @@ class DatasetIndexingStatusApi(Resource):
|
||||
"completed_segments": completed_segments,
|
||||
"total_segments": total_segments,
|
||||
}
|
||||
documents_status.append(marshal(document_dict, document_status_fields))
|
||||
data = {"data": documents_status}
|
||||
return data, 200
|
||||
documents_status.append(document_dict)
|
||||
return dump_response(DocumentStatusListResponse, {"data": documents_status}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/api-keys")
|
||||
@ -873,7 +975,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/api-keys/<string:status>")
|
||||
@ -907,13 +1009,18 @@ class DatasetApiBaseUrlApi(Resource):
|
||||
class DatasetRetrievalSettingApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting")
|
||||
@console_ns.doc(description="Get dataset retrieval settings")
|
||||
@console_ns.response(200, "Retrieval settings retrieved successfully")
|
||||
@console_ns.response(
|
||||
200, "Retrieval settings retrieved successfully", console_ns.models[RetrievalSettingResponse.__name__]
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=False)
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=False),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/retrieval-setting/<string:vector_type>")
|
||||
@ -921,12 +1028,19 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
@console_ns.doc("get_dataset_retrieval_setting_mock")
|
||||
@console_ns.doc(description="Get mock dataset retrieval settings by vector type")
|
||||
@console_ns.doc(params={"vector_type": "Vector store type"})
|
||||
@console_ns.response(200, "Mock retrieval settings retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Mock retrieval settings retrieved successfully",
|
||||
console_ns.models[RetrievalSettingResponse.__name__],
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, vector_type):
|
||||
return _get_retrieval_methods_by_vector_type(vector_type, is_mock=True)
|
||||
return dump_response(
|
||||
RetrievalSettingResponse,
|
||||
_get_retrieval_methods_by_vector_type(vector_type, is_mock=True),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/error-docs")
|
||||
@ -934,7 +1048,7 @@ class DatasetErrorDocs(Resource):
|
||||
@console_ns.doc("get_dataset_error_docs")
|
||||
@console_ns.doc(description="Get dataset error documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Error documents retrieved successfully")
|
||||
@console_ns.response(200, "Error documents retrieved successfully", console_ns.models[ErrorDocsResponse.__name__])
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -946,7 +1060,7 @@ class DatasetErrorDocs(Resource):
|
||||
raise NotFound("Dataset not found.")
|
||||
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
|
||||
|
||||
return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200
|
||||
return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/permission-part-users")
|
||||
@ -954,7 +1068,11 @@ class DatasetPermissionUserListApi(Resource):
|
||||
@console_ns.doc("get_dataset_permission_users")
|
||||
@console_ns.doc(description="Get dataset permission user list")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Permission users retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Permission users retrieved successfully",
|
||||
console_ns.models[PartialMemberListResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@setup_required
|
||||
@ -973,9 +1091,7 @@ class DatasetPermissionUserListApi(Resource):
|
||||
|
||||
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
|
||||
return {
|
||||
"data": partial_members_list,
|
||||
}, 200
|
||||
return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/auto-disable-logs")
|
||||
@ -983,7 +1099,11 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
@console_ns.doc("get_dataset_auto_disable_logs")
|
||||
@console_ns.doc(description="Get dataset auto disable logs")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.response(200, "Auto disable logs retrieved successfully")
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Auto disable logs retrieved successfully",
|
||||
console_ns.models[AutoDisableLogsResponse.__name__],
|
||||
)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -993,4 +1113,4 @@ class DatasetAutoDisableLogApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
|
||||
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
|
||||
|
||||
@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/init")
|
||||
@ -966,7 +966,7 @@ class DocumentApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot pause completed document.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
|
||||
@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Document is not in paused status.")
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/retry")
|
||||
@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource):
|
||||
# retry document
|
||||
DocumentService.retry_document(dataset_id, retry_documents)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segments(segment_ids, document, dataset)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>")
|
||||
@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource):
|
||||
SegmentService.delete_child_chunk(child_chunk, dataset)
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id)
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@ -22,7 +26,12 @@ from services.metadata_service import MetadataService
|
||||
register_schema_models(
|
||||
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
|
||||
)
|
||||
register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
console_ns,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -31,7 +40,7 @@ class DatasetMetadataCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataArgs.__name__])
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -44,18 +53,22 @@ class DatasetMetadataCreateApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return metadata, 201
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -64,7 +77,7 @@ class DatasetMetadataApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
@console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__])
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@ -79,7 +92,7 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
|
||||
return metadata, 200
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -96,7 +109,8 @@ class DatasetMetadataApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return {"result": "success"}, 204
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/metadata/built-in")
|
||||
@ -105,9 +119,14 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self):
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -116,7 +135,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(204, "Action completed successfully")
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -130,7 +149,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
# Frontend callers only await success and invalidate metadata caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -140,7 +160,10 @@ class DocumentMetadataEditApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@console_ns.expect(console_ns.models[MetadataOperationData.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@console_ns.response(
|
||||
204,
|
||||
"Documents metadata updated successfully",
|
||||
)
|
||||
def post(self, dataset_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -153,4 +176,5 @@ class DocumentMetadataEditApi(Resource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
# Frontend callers only await success and invalidate caches; no response body is consumed.
|
||||
return "", 204
|
||||
|
||||
@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource):
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
|
||||
@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource):
|
||||
db.session.delete(installed_app)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "message": "App uninstalled successfully"}, 204
|
||||
return "", 204
|
||||
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__])
|
||||
def patch(self, installed_app):
|
||||
|
||||
@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource):
|
||||
|
||||
SavedMessageService.delete(app_model, current_user, message_id)
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, cloud_utm_record, setup_required
|
||||
|
||||
register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel)
|
||||
register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel)
|
||||
|
||||
|
||||
@console_ns.route("/features")
|
||||
@ -28,7 +28,32 @@ class FeatureApi(Resource):
|
||||
"""Get feature configuration for current tenant"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
return FeatureService.get_features(current_tenant_id).model_dump()
|
||||
payload = FeatureService.get_features(
|
||||
current_tenant_id,
|
||||
exclude_vector_space=True,
|
||||
).model_dump()
|
||||
payload.pop("vector_space", None)
|
||||
return payload
|
||||
|
||||
|
||||
@console_ns.route("/features/vector-space")
class FeatureVectorSpaceApi(Resource):
    # Console endpoint that returns whatever FeatureService.get_vector_space
    # reports for the caller's tenant, dumped to a plain dict. The documented
    # 200 response schema is LimitationModel.

    @console_ns.doc("get_tenant_feature_vector_space")
    @console_ns.doc(description="Get vector-space usage and limit for current tenant")
    @console_ns.response(200, "Success", console_ns.models[LimitationModel.__name__])
    @setup_required
    @login_required
    @account_initialization_required
    @cloud_utm_record
    def get(self):
        """Get vector-space usage and limit for current tenant"""
        # Only the tenant id half of the (account, tenant_id) pair is used here.
        _account, tenant_id = current_account_with_tenant()

        vector_space = FeatureService.get_vector_space(tenant_id)
        return vector_space.model_dump()
|
||||
|
||||
|
||||
@console_ns.route("/system-features")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -34,7 +33,7 @@ class GetRemoteFileInfo(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__])
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
|
||||
@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/credentials/switch")
|
||||
|
||||
@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource):
|
||||
tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials")
|
||||
@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
credential_id=args.credential_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/model-providers/<path:provider>/models/credentials/switch")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from typing import Literal, override
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@ -76,11 +76,13 @@ def _enum_value(value):
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
    """Marshalling field that emits a workflow run's status via ``_enum_value``."""

    @override
    def output(self, key, obj: WorkflowRun, **kwargs):
        # `key` and `kwargs` are part of the overridden signature but are not
        # consulted; the value always comes from the run's `status` attribute.
        run_status = obj.status
        return _enum_value(run_status)
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
@override
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
status = _enum_value(obj.status)
|
||||
if status == WorkflowExecutionStatus.PAUSED.value:
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_enum_models, register_response_schema_models, register_schema_models
|
||||
from controllers.common.schema import (
|
||||
query_params_from_model,
|
||||
register_enum_models,
|
||||
register_response_schema_models,
|
||||
register_schema_models,
|
||||
)
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@ -17,9 +21,10 @@ from controllers.service_api.wraps import (
|
||||
)
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import DataSetTag
|
||||
from fields.base import ResponseModel
|
||||
from fields.dataset_fields import DatasetDetailResponse
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response
|
||||
from libs.login import current_user
|
||||
from models.account import Account
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
@ -119,6 +124,21 @@ class TagUnbindingPayload(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
class KnowledgeTagResponse(ResponseModel):
    # Response schema for a single knowledge tag: id, name, type and an
    # optional binding count.

    # Allow numeric input values to be accepted for the str-typed fields below.
    model_config = ConfigDict(coerce_numbers_to_str=True)

    id: str
    name: str
    type: str
    # TODO: The public Service API docs expose binding_count as string|null.
    # Keep matching the old RESTX fields.String coercion until that contract is intentionally migrated.
    binding_count: str | None = None
|
||||
|
||||
|
||||
class KnowledgeTagListResponse(RootModel[list[KnowledgeTagResponse]]):
    # Root model whose payload is a bare list of KnowledgeTagResponse items;
    # it declares no fields of its own.
    pass
|
||||
|
||||
|
||||
class DatasetListQuery(BaseModel):
|
||||
page: int = Field(default=1, description="Page number")
|
||||
limit: int = Field(default=20, description="Number of items per page")
|
||||
@ -127,6 +147,29 @@ class DatasetListQuery(BaseModel):
|
||||
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
|
||||
|
||||
|
||||
class DatasetDetailWithPartialMembersResponse(DatasetDetailResponse):
    # Dataset detail schema extended with an optional list of strings under
    # `partial_member_list`; the field defaults to None when not supplied.
    partial_member_list: list[str] | None = None
|
||||
|
||||
|
||||
# todo: duplicate code, but the partial_member_list has different nullability
class DatasetListResponse(ResponseModel):
    # Paginated envelope for a dataset listing: the dataset detail items plus
    # the paging values (has_more, limit, total, page). All fields are required.
    data: list[DatasetDetailResponse]
    has_more: bool
    limit: int
    total: int
    page: int
|
||||
|
||||
|
||||
class DatasetBoundTagResponse(ResponseModel):
    # Minimal tag representation: just the tag's id and name, both required.
    id: str
    name: str
|
||||
|
||||
|
||||
class DatasetBoundTagListResponse(ResponseModel):
    # Envelope for a list of DatasetBoundTagResponse items together with a
    # `total` count; both fields are required.
    data: list[DatasetBoundTagResponse]
    total: int
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
@ -137,9 +180,17 @@ register_schema_models(
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
DatasetListQuery,
|
||||
DataSetTag,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
SimpleResultResponse,
|
||||
KnowledgeTagResponse,
|
||||
KnowledgeTagListResponse,
|
||||
DatasetDetailResponse,
|
||||
DatasetDetailWithPartialMembersResponse,
|
||||
DatasetListResponse,
|
||||
DatasetBoundTagListResponse,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets")
|
||||
@ -154,9 +205,18 @@ class DatasetListApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.doc(params=query_params_from_model(DatasetListQuery))
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Datasets retrieved successfully",
|
||||
service_api_ns.models[DatasetListResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id):
|
||||
"""Resource for getting datasets."""
|
||||
query = DatasetListQuery.model_validate(request.args.to_dict())
|
||||
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
|
||||
if "tag_ids" in request.args:
|
||||
query_params["tag_ids"] = request.args.getlist("tag_ids")
|
||||
query = DatasetListQuery.model_validate(query_params)
|
||||
# provider = request.args.get("provider", default="vendor")
|
||||
|
||||
datasets, total = DatasetService.get_datasets(
|
||||
@ -175,22 +235,17 @@ class DatasetListApi(DatasetApiResource):
|
||||
for embedding_model in embedding_models:
|
||||
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
|
||||
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
data = [dump_response(DatasetDetailResponse, dataset) for dataset in datasets]
|
||||
for item in data:
|
||||
if (
|
||||
item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY # pyrefly: ignore[bad-index]
|
||||
and item["embedding_model_provider"] # pyrefly: ignore[bad-index]
|
||||
):
|
||||
item["embedding_model_provider"] = str( # pyrefly: ignore[unsupported-operation]
|
||||
ModelProviderID(item["embedding_model_provider"]) # pyrefly: ignore[bad-index]
|
||||
)
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # pyrefly: ignore[bad-index]
|
||||
if item["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY and item["embedding_model_provider"]:
|
||||
item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item["embedding_available"] = True # type: ignore
|
||||
item["embedding_available"] = True
|
||||
else:
|
||||
item["embedding_available"] = False # type: ignore
|
||||
item["embedding_available"] = False
|
||||
else:
|
||||
item["embedding_available"] = True # type: ignore
|
||||
item["embedding_available"] = True
|
||||
response = {
|
||||
"data": data,
|
||||
"has_more": len(datasets) == query.limit,
|
||||
@ -198,7 +253,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
"total": total,
|
||||
"page": query.page,
|
||||
}
|
||||
return response, 200
|
||||
return dump_response(DatasetListResponse, response), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@ -210,6 +265,11 @@ class DatasetListApi(DatasetApiResource):
|
||||
400: "Bad request - invalid parameters",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset created successfully",
|
||||
service_api_ns.models[DatasetDetailResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
@ -253,7 +313,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
return dump_response(DatasetDetailResponse, dataset), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>")
|
||||
@ -271,6 +331,11 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset retrieved successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
def get(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -280,7 +345,7 @@ class DatasetApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
data = dump_response(DatasetDetailResponse, dataset)
|
||||
# check embedding setting
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
@ -312,7 +377,13 @@ class DatasetApi(DatasetApiResource):
|
||||
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
data.update({"partial_member_list": part_users_list})
|
||||
|
||||
return data, 200
|
||||
return (
|
||||
DatasetDetailWithPartialMembersResponse.model_validate(data).model_dump(
|
||||
mode="json",
|
||||
exclude={"partial_member_list"} if "partial_member_list" not in data else set(),
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@ -326,6 +397,11 @@ class DatasetApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Dataset updated successfully",
|
||||
service_api_ns.models[DatasetDetailWithPartialMembersResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
@ -376,7 +452,7 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
|
||||
result_data = dump_response(DatasetDetailResponse, dataset)
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@ -389,7 +465,7 @@ class DatasetApi(DatasetApiResource):
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
result_data.update({"partial_member_list": partial_member_list})
|
||||
|
||||
return result_data, 200
|
||||
return DatasetDetailWithPartialMembersResponse.model_validate(result_data).model_dump(mode="json"), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset")
|
||||
@service_api_ns.doc(description="Delete a dataset")
|
||||
@ -502,7 +578,7 @@ class DocumentStatusApi(DatasetApiResource):
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(SimpleResultResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags")
|
||||
@ -515,14 +591,18 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[KnowledgeTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
assert cid is not None
|
||||
tags = TagService.get_tags("knowledge", cid)
|
||||
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
|
||||
return [tag.model_dump(mode="json") for tag in tag_models], 200
|
||||
return dump_response(KnowledgeTagListResponse, tags), 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@ -534,6 +614,11 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag created successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
@ -543,9 +628,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE))
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
).model_dump(mode="json")
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0},
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@ -558,6 +644,11 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tag updated successfully",
|
||||
service_api_ns.models[KnowledgeTagResponse.__name__],
|
||||
)
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -569,9 +660,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
response = DataSetTag.model_validate(
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
).model_dump(mode="json")
|
||||
response = dump_response(
|
||||
KnowledgeTagResponse,
|
||||
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count},
|
||||
)
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@ -656,6 +748,11 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Tags retrieved successfully",
|
||||
service_api_ns.models[DatasetBoundTagListResponse.__name__],
|
||||
)
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
@ -663,5 +760,4 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
return dump_response(DatasetBoundTagListResponse, {"data": tags_list, "total": len(tags)}), 200
|
||||
|
||||
@ -1,15 +1,19 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MetadataUpdatePayload
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from fields.dataset_fields import (
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
from libs.helper import dump_response
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
@ -27,7 +31,13 @@ register_schema_models(
|
||||
DocumentMetadataOperation,
|
||||
MetadataOperationData,
|
||||
)
|
||||
register_response_schema_models(service_api_ns, SimpleResultResponse)
|
||||
register_response_schema_models(
|
||||
service_api_ns,
|
||||
DatasetMetadataActionResponse,
|
||||
DatasetMetadataBuiltInFieldsResponse,
|
||||
DatasetMetadataListResponse,
|
||||
DatasetMetadataResponse,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
@ -43,6 +53,9 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
@ -55,7 +68,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.create_metadata(dataset_id_str, metadata_args)
|
||||
return marshal(metadata, dataset_metadata_fields), 201
|
||||
return dump_response(DatasetMetadataResponse, metadata), 201
|
||||
|
||||
@service_api_ns.doc("get_dataset_metadata")
|
||||
@service_api_ns.doc(description="Get all metadata for a dataset")
|
||||
@ -67,13 +80,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
404: "Dataset not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__]
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all metadata for a dataset."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
return MetadataService.get_dataset_metadatas(dataset), 200
|
||||
metadata = MetadataService.get_dataset_metadatas(dataset)
|
||||
return dump_response(DatasetMetadataListResponse, metadata), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
@ -89,6 +106,9 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
@ -102,7 +122,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
return dump_response(DatasetMetadataResponse, metadata), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
@service_api_ns.doc(description="Delete metadata")
|
||||
@ -114,6 +134,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
404: "Dataset or metadata not found",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(204, "Metadata deleted successfully")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Delete metadata."""
|
||||
@ -138,10 +159,15 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Built-in fields retrieved successfully",
|
||||
service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__],
|
||||
)
|
||||
def get(self, tenant_id, dataset_id):
|
||||
"""Get all built-in metadata fields."""
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>")
|
||||
@ -157,9 +183,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Action completed successfully",
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__]
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]):
|
||||
@ -175,7 +199,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
case "disable":
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
@ -194,7 +218,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.response(
|
||||
200,
|
||||
"Documents metadata updated successfully",
|
||||
service_api_ns.models[SimpleResultResponse.__name__],
|
||||
service_api_ns.models[DatasetMetadataActionResponse.__name__],
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
@ -209,4 +233,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200
|
||||
|
||||
@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse)
|
||||
@service_api_ns.route("/")
|
||||
class IndexApi(Resource):
|
||||
@service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__])
|
||||
def get(self):
|
||||
def get(self) -> dict[str, str]:
|
||||
return {
|
||||
"welcome": "Dify OpenAPI",
|
||||
"api_version": "v1",
|
||||
|
||||
@ -136,7 +136,7 @@ class ConversationApi(WebApiResource):
|
||||
ConversationService.delete(app_model, conversation_id, end_user)
|
||||
except ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@web_ns.route("/conversations/<uuid:c_id>/name")
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
import services
|
||||
@ -59,7 +58,7 @@ class RemoteFileInfoApi(WebApiResource):
|
||||
Raises:
|
||||
HTTPException: If the remote file cannot be accessed
|
||||
"""
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
decoded_url = helpers.decode_remote_url(url, request.query_string)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# failed back to get method
|
||||
|
||||
@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource):
|
||||
|
||||
SavedMessageService.delete(app_model, end_user, message_id)
|
||||
|
||||
return ResultResponse(result="success").model_dump(mode="json"), 204
|
||||
return "", 204
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.file import file_manager
|
||||
@ -66,6 +67,7 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from typing import override
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -51,6 +52,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
|
||||
|
||||
return historic_prompt
|
||||
|
||||
@override
|
||||
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
@override
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
@override
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
|
||||
@ -55,6 +55,7 @@ from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.workflow_draft_variable_service import (
|
||||
DraftVarLoader,
|
||||
WorkflowDraftVariableService,
|
||||
@ -145,9 +146,15 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
if conversation_id:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
try:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
except ConversationNotExistsError:
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
conversation = None
|
||||
else:
|
||||
raise
|
||||
|
||||
# parse files
|
||||
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, Any, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from pydantic import JsonValue
|
||||
|
||||
@ -16,6 +16,7 @@ from core.app.entities.task_entities import (
|
||||
|
||||
class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking full response.
|
||||
@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse):
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from graphon.enums import NodeType
|
||||
|
||||
@ -29,5 +29,6 @@ class DraftVariableSaverFactory(Protocol):
|
||||
|
||||
|
||||
class NoopDraftVariableSaver(DraftVariableSaver):
|
||||
@override
|
||||
def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None:
|
||||
return None
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self._app_mode = app_mode
|
||||
self._message_id = str(message_id)
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
@override
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import Any, cast, override
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse]
|
||||
):
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_full_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_blocking_simple_response(
|
||||
cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse
|
||||
) -> dict[str, Any]:
|
||||
@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
@ -73,6 +76,7 @@ class WorkflowAppGenerateResponseConverter(
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict[str, Any] | str, None, None]:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import override
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -31,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
) -> None:
|
||||
self._scope_getter = scope_getter
|
||||
|
||||
@override
|
||||
def current_scope(self) -> FileAccessScope | None:
|
||||
return self._scope_getter()
|
||||
|
||||
@override
|
||||
def apply_upload_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[UploadFile]],
|
||||
@ -62,6 +65,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def apply_tool_file_filters(
|
||||
self,
|
||||
stmt: Select[tuple[ToolFile]],
|
||||
@ -78,6 +82,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
|
||||
return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id)
|
||||
|
||||
@override
|
||||
def get_upload_file(
|
||||
self,
|
||||
*,
|
||||
@ -95,6 +100,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol):
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
@override
|
||||
def get_tool_file(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -8,6 +8,7 @@ scope updates that matter to chat applications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
|
||||
@ -23,9 +24,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._conversation_variable_updater = conversation_variable_updater
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
    """No-op: this layer takes no action when the graph starts."""
    return None
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
if not isinstance(event, NodeRunVariableUpdatedEvent):
|
||||
return
|
||||
@ -44,5 +47,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self
|
||||
from typing import Annotated, Literal, Self, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Engine
|
||||
@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def _get_repo(self) -> APIWorkflowRunRepository:
    """Create an API workflow-run repository from this layer's session maker."""
    session_maker = self._session_maker
    repository = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
    return repository
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""
|
||||
Called when graph execution starts.
|
||||
@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
pause_reasons=event.reasons,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""
|
||||
Called when graph execution ends.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
|
||||
@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer):
|
||||
super().__init__()
|
||||
self._paused = False
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
    """Clear the paused flag when the graph starts."""
    paused = False
    self._paused = paused
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Handle the paused event, stash runtime state into storage and wait for resume.
|
||||
@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer):
|
||||
if isinstance(event, GraphRunPausedEvent):
|
||||
self._paused = True
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None):
|
||||
""" """
|
||||
self._paused = False
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, override
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler # type: ignore
|
||||
|
||||
@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
except Exception:
|
||||
logger.exception("scheduler error during check if the workflow need to be suspended")
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
|
||||
"""
|
||||
Start timer to check if the workflow need to be suspended.
|
||||
@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer):
|
||||
id=self.schedule_id,
|
||||
)
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
    """No-op: this layer does not react to individual engine events."""
    return None
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self.stopped = True
|
||||
# remove the scheduler
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
self.start_time = start_time
|
||||
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
|
||||
|
||||
@override
|
||||
def on_graph_start(self):
    """No-op: nothing is done when the graph starts."""
    return None
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent):
|
||||
"""
|
||||
Update trigger log with success or failure.
|
||||
@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer):
|
||||
repo.update(trigger_log)
|
||||
session.commit()
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
pass
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol
|
||||
@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
self._file_access_controller = file_access_controller
|
||||
|
||||
@property
|
||||
@override
|
||||
def multimodal_send_format(self) -> str:
    """Return the ``MULTIMODAL_SEND_FORMAT`` value read from ``dify_config``."""
    send_format = dify_config.MULTIMODAL_SEND_FORMAT
    return send_format
|
||||
|
||||
@override
|
||||
def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol:
    """Issue a GET for ``url`` through ``graphon_ssrf_proxy`` and return its response."""
    response = graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects)
    return response
|
||||
|
||||
@override
|
||||
def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
    """Load ``path`` via ``storage.load``, forwarding the ``stream`` flag unchanged."""
    loaded = storage.load(path, stream=stream)
    return loaded
|
||||
|
||||
@override
|
||||
def load_file_bytes(self, *, file: File) -> bytes:
|
||||
storage_key = self._resolve_storage_key(file=file)
|
||||
data = storage.load(storage_key, stream=False)
|
||||
@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
raise ValueError(f"file {storage_key} is not a bytes object")
|
||||
return data
|
||||
|
||||
@override
|
||||
def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
|
||||
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return file.remote_url
|
||||
@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
)
|
||||
return None
|
||||
|
||||
@override
|
||||
def resolve_upload_file_url(
|
||||
self,
|
||||
*,
|
||||
@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
|
||||
query["as_attachment"] = "true"
|
||||
return f"{url}?{urllib.parse.urlencode(query)}"
|
||||
|
||||
@override
|
||||
def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
    """
    Return a signed URL for a tool file.

    The access assertion runs before signing. Its return value is not used here,
    so a rejection presumably surfaces as an exception from that helper (confirm
    against ``_assert_tool_file_access``).
    """
    self._assert_tool_file_access(tool_file_id=tool_file_id)
    signed_url = sign_tool_file(
        tool_file_id=tool_file_id,
        extension=extension,
        for_external=for_external,
    )
    return signed_url
|
||||
|
||||
@override
|
||||
def verify_preview_signature(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -12,7 +12,7 @@ state.
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, override
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.helper.trace_id_helper import ParentTraceContext
|
||||
@ -98,12 +98,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
# ------------------------------------------------------------------
|
||||
# GraphEngineLayer lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
@override
|
||||
def on_graph_start(self) -> None:
    """Reset per-run state: drop the workflow execution, empty both node caches, zero the sequence."""
    self._workflow_execution = None
    # Both containers are emptied in place, in the same order as before.
    for node_store in (self._node_execution_cache, self._node_snapshots):
        node_store.clear()
    self._node_sequence = 0
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
match event:
|
||||
case GraphRunStartedEvent():
|
||||
@ -131,6 +133,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
case NodeRunPauseRequestedEvent():
|
||||
self._handle_node_pause_requested(event)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
return
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
    """Return the provider type constant for local-file datasources."""
    provider_type = DatasourceProviderType.LOCAL_FILE
    return provider_type
|
||||
|
||||
@override
|
||||
def get_icon_url(self, tenant_id: str) -> str:
    """Return the icon stored on this plugin; ``tenant_id`` is not consulted."""
    icon = self.icon
    return icon
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@override
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
    """Return the provider type constant for online-document datasources."""
    provider_type = DatasourceProviderType.ONLINE_DOCUMENT
    return provider_type
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -67,5 +68,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
    """Return the provider type constant for online-drive datasources."""
    provider_type = DatasourceProviderType.ONLINE_DRIVE
    return provider_type
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def datasource_provider_type(self) -> str:
    """Return the provider type constant for website-crawl datasources."""
    provider_type = DatasourceProviderType.WEBSITE_CRAWL
    return provider_type
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import override
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
@override
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
|
||||
@ -6,7 +6,7 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from sqlalchemy import func, select
|
||||
@ -1889,6 +1889,7 @@ class ProviderConfigurations(BaseModel):
|
||||
key = str(ModelProviderID(key))
|
||||
return key in self.configurations
|
||||
|
||||
@override
|
||||
def __iter__(self):
    """Iterate over ``(key, value)`` pairs of ``self.configurations``."""
    # Yielding tuples keeps the shape of BaseModel's own __iter__.
    pairs = self.configurations.items()
    yield from pairs
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -29,6 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -50,6 +51,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
@override
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class JavascriptCodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
    """Return the language identifier for JavaScript code nodes."""
    language = CodeLanguage.JAVASCRIPT
    return language
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class NodeJsTemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
_template_b64_placeholder: str = "{{template_b64}}"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def transform_response(cls, response: str):
|
||||
"""
|
||||
Transform response to dict
|
||||
@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return {"result": cls.extract_result_str_from_response(response)}
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Override base class to use base64 encoding for template code.
|
||||
@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f"""
|
||||
import jinja2
|
||||
@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer):
|
||||
return runner_script
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_preload_script(cls) -> str:
|
||||
preload_script = dedent("""
|
||||
import jinja2
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
|
||||
class Python3CodeProvider(CodeNodeProvider):
|
||||
@staticmethod
|
||||
@override
|
||||
def get_language() -> str:
    """Return the language identifier for Python 3 code nodes."""
    language = CodeLanguage.PYTHON3
    return language
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def get_default_code(cls) -> str:
|
||||
return dedent(
|
||||
"""
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
from textwrap import dedent
|
||||
from typing import override
|
||||
|
||||
from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
|
||||
|
||||
class Python3TemplateTransformer(TemplateTransformer):
|
||||
@classmethod
|
||||
@override
|
||||
def get_runner_script(cls) -> str:
|
||||
runner_script = dedent(f""" {cls._code_placeholder}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
@override
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
|
||||
@ -43,13 +43,16 @@ request_error = httpx.RequestError
|
||||
max_retries_exceeded_error = MaxRetriesExceededError
|
||||
|
||||
|
||||
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
|
||||
def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]:
    """
    Create one proxy transport per URL scheme.

    ``verify`` is handed to both transports so they apply the same TLS
    verification setting the caller uses for the SSRF client.
    """
    http_transport = httpx.HTTPTransport(
        proxy=dify_config.SSRF_PROXY_HTTP_URL,
        verify=verify,
    )
    https_transport = httpx.HTTPTransport(
        proxy=dify_config.SSRF_PROXY_HTTPS_URL,
        verify=verify,
    )
    return {"http://": http_transport, "https://": https_transport}
|
||||
|
||||
@ -64,7 +67,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client:
|
||||
|
||||
if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
return httpx.Client(
|
||||
mounts=_create_proxy_mounts(),
|
||||
mounts=_create_proxy_mounts(verify=verify),
|
||||
verify=verify,
|
||||
limits=_SSRF_CLIENT_LIMITS,
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
import flask
|
||||
|
||||
@ -15,6 +16,7 @@ class TraceContextFilter(logging.Filter):
|
||||
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
@ -54,6 +56,7 @@ class IdentityContextFilter(logging.Filter):
|
||||
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
|
||||
"""
|
||||
|
||||
@override
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
from typing import Any, NotRequired, TypedDict, override
|
||||
|
||||
import orjson
|
||||
|
||||
@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
super().__init__()
|
||||
self._service_name = service_name or dify_config.APPLICATION_NAME
|
||||
|
||||
@override
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_dict = self._build_log_dict(record)
|
||||
try:
|
||||
|
||||
@ -7,7 +7,7 @@ authentication failures and retries operations after refreshing tokens.
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -159,6 +159,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
# Reset retry flag after operation completes
|
||||
self._has_retried = False
|
||||
|
||||
@override
|
||||
def __enter__(self):
|
||||
"""Enter the context manager with retry support."""
|
||||
|
||||
@ -168,6 +169,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
|
||||
return self._execute_with_retry(initialize_with_retry)
|
||||
|
||||
@override
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""
|
||||
List available tools from the MCP server with auth retry.
|
||||
@ -180,6 +182,7 @@ class MCPClientWithAuthRetry(MCPClient):
|
||||
"""
|
||||
return self._execute_with_retry(super().list_tools)
|
||||
|
||||
@override
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
|
||||
"""
|
||||
Invoke a tool on the MCP server with auth retry.
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import queue
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
from typing import Any, Protocol, override
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
@ -159,6 +159,7 @@ class ClientSession(
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@override
|
||||
def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
@ -326,6 +327,7 @@ class ClientSession(
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
@ -351,6 +353,7 @@ class ClientSession(
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
@override
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
@ -358,6 +361,7 @@ class ClientSession(
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
@override
|
||||
def _received_notification(self, notification: types.ServerNotification):
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
@ -25,6 +25,7 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -43,6 +44,7 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -59,6 +61,7 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
|
||||
@ -8,6 +8,7 @@ class KeywordsModeration(Moderation):
|
||||
name: str = "keywords"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -28,6 +29,7 @@ class KeywordsModeration(Moderation):
|
||||
if len(keywords_row_len) > 100:
|
||||
raise ValueError("the number of rows for the keywords must be less than 100")
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -49,6 +51,7 @@ class KeywordsModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult
|
||||
@ -9,6 +9,7 @@ class OpenAIModeration(Moderation):
|
||||
name: str = "openai_moderation"
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
@ -19,6 +20,7 @@ class OpenAIModeration(Moderation):
|
||||
"""
|
||||
cls._validate_inputs_and_outputs_config(config, True)
|
||||
|
||||
@override
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
@ -36,6 +38,7 @@ class OpenAIModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
@override
|
||||
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import override
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
@ -11,6 +12,7 @@ class PluginDaemonError(Exception):
|
||||
def __init__(self, description: str):
    """Store ``description`` on the instance for later rendering."""
    setattr(self, "description", description)
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
    """Render as ``req_id: <request id> <class name>: <description>``."""
    # The request id is fetched at formatting time, before the other parts.
    request_id = get_request_id()
    error_name = self.__class__.__name__
    return f"req_id: {request_id} {error_name}: {self.description}"
|
||||
|
||||
@ -4,7 +4,7 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Generator, Iterable, Sequence
|
||||
from threading import Lock
|
||||
from typing import IO, Any, Literal, cast, overload
|
||||
from typing import IO, Any, Literal, cast, overload, override
|
||||
|
||||
from pydantic import ValidationError
|
||||
from redis import RedisError
|
||||
@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
self._provider_entities = None
|
||||
self._provider_entities_lock = Lock()
|
||||
|
||||
@override
|
||||
def fetch_model_providers(self) -> Sequence[ProviderEntity]:
|
||||
if self._provider_entities is not None:
|
||||
return self._provider_entities
|
||||
@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
|
||||
return self._provider_entities
|
||||
|
||||
@override
|
||||
def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
|
||||
provider_schema = self._get_provider_schema(provider)
|
||||
|
||||
@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
mime_type = image_mime_types.get(extension, "image/png")
|
||||
return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type
|
||||
|
||||
@override
|
||||
def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None:
|
||||
plugin_id, provider_name = self._split_provider(provider)
|
||||
self.client.validate_provider_credentials(
|
||||
@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
*,
|
||||
@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_model_schema(
|
||||
self,
|
||||
*,
|
||||
@ -294,6 +299,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunk, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm(
|
||||
self,
|
||||
*,
|
||||
@ -357,6 +363,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream: Literal[True],
|
||||
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
|
||||
|
||||
@override
|
||||
def invoke_llm_with_structured_output(
|
||||
self,
|
||||
*,
|
||||
@ -396,6 +403,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -422,6 +430,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
tools=list(tools) if tools else None,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -443,6 +452,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_embedding(
|
||||
self,
|
||||
*,
|
||||
@ -464,6 +474,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
input_type=input_type,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
*,
|
||||
@ -483,6 +494,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -508,6 +520,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_multimodal_rerank(
|
||||
self,
|
||||
*,
|
||||
@ -533,6 +546,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_tts(
|
||||
self,
|
||||
*,
|
||||
@ -554,6 +568,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
*,
|
||||
@ -573,6 +588,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
language=language,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
*,
|
||||
@ -592,6 +608,7 @@ class PluginModelRuntime(ModelRuntime):
|
||||
file=file,
|
||||
)
|
||||
|
||||
@override
|
||||
def invoke_moderation(
|
||||
self,
|
||||
*,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, override
|
||||
|
||||
import orjson
|
||||
from pydantic import BaseModel
|
||||
@ -29,6 +29,7 @@ class Jieba(BaseKeyword):
|
||||
super().__init__(dataset)
|
||||
self._config = KeywordTableConfig()
|
||||
|
||||
@override
|
||||
def create(self, texts: list[Document], **kwargs) -> BaseKeyword:
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -48,6 +49,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return self
|
||||
|
||||
@override
|
||||
def add_texts(self, texts: list[Document], **kwargs):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -72,12 +74,14 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def text_exists(self, id: str) -> bool:
    """Return whether ``id`` is listed under any keyword of the dataset's keyword table.

    Args:
        id: Identifier to look for among the values of the keyword table.

    Returns:
        ``True`` if at least one keyword's id collection contains ``id``;
        ``False`` when the table is missing, empty, or does not mention it.
    """
    keyword_table = self._get_dataset_keyword_table()
    # A missing table and an empty table both mean "nothing indexed". The
    # previous ``set.union(*values)`` raised TypeError for an empty table.
    if not keyword_table:
        return False
    # Short-circuit on the first hit instead of materialising the full union.
    return any(id in indexed_ids for indexed_ids in keyword_table.values())
|
||||
|
||||
@override
|
||||
def delete_by_ids(self, ids: list[str]):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
@ -87,6 +91,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
@override
|
||||
def search(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
@ -122,6 +127,7 @@ class Jieba(BaseKeyword):
|
||||
|
||||
return documents
|
||||
|
||||
@override
|
||||
def delete(self):
|
||||
lock_name = f"keyword_indexing_lock_{self.dataset.id}"
|
||||
with redis_client.lock(lock_name, timeout=600):
|
||||
|
||||
@ -2,7 +2,7 @@ import base64
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, override
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@ -72,21 +72,27 @@ class _LazyEmbeddings(Embeddings):
|
||||
self._real = CacheEmbedding(embedding_model)
|
||||
return self._real
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
    """Embed ``texts`` using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return delegate.embed_documents(texts)
|
||||
|
||||
@override
|
||||
def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]:
    """Embed multimodal documents using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return delegate.embed_multimodal_documents(multimodel_documents)
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
    """Embed a single query string using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return delegate.embed_query(text)
|
||||
|
||||
@override
|
||||
def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]:
    """Embed one multimodal query using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return delegate.embed_multimodal_query(multimodel_document)
|
||||
|
||||
@override
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
    """Asynchronously embed ``texts`` using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return await delegate.aembed_documents(texts)
|
||||
|
||||
@override
|
||||
async def aembed_query(self, text: str) -> list[float]:
    """Asynchronously embed one query using the embeddings object returned by ``self._ensure()``."""
    delegate = self._ensure()
    return await delegate.aembed_query(text)
|
||||
|
||||
|
||||
@ -37,6 +37,10 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
PluginAgentStrategyResolver,
|
||||
)
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from core.workflow.nodes.agent_v2 import DifyAgentNode
|
||||
from core.workflow.nodes.agent_v2.binding_resolver import WorkflowAgentBindingResolver
|
||||
from core.workflow.nodes.agent_v2.output_adapter import WorkflowAgentOutputAdapter
|
||||
from core.workflow.nodes.agent_v2.runtime_request_builder import WorkflowAgentRuntimeRequestBuilder
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
|
||||
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
@ -438,12 +442,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
"tool_file_manager": self._bound_tool_file_manager_factory(),
|
||||
"runtime": self._tool_runtime,
|
||||
},
|
||||
BuiltinNodeTypes.AGENT: lambda: {
|
||||
"strategy_resolver": self._agent_strategy_resolver,
|
||||
"presentation_provider": self._agent_strategy_presentation_provider,
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
BuiltinNodeTypes.AGENT: lambda: self._build_agent_node_init_kwargs(node_class=node_class),
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
constructor_node_data = resolved_node_data.model_dump(mode="python", by_alias=True)
|
||||
@ -469,6 +468,32 @@ class DifyNodeFactory(NodeFactory):
|
||||
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
    """Look up the node class for ``node_type``/``node_version`` via ``resolve_workflow_node_class``."""
    resolved_class = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
    return resolved_class
|
||||
|
||||
def _build_agent_node_init_kwargs(self, *, node_class: type[Node]) -> dict[str, object]:
    """Return the constructor keyword arguments for an Agent node class.

    Classes that are not ``DifyAgentNode`` subclasses receive the
    strategy-based collaborators held by this factory. ``DifyAgentNode``
    subclasses receive freshly built collaborators for talking to the Agent
    backend, configured from ``dify_config``.
    """
    if not issubclass(node_class, DifyAgentNode):
        return {
            "strategy_resolver": self._agent_strategy_resolver,
            "presentation_provider": self._agent_strategy_presentation_provider,
            "runtime_support": self._agent_runtime_support,
            "message_transformer": self._agent_message_transformer,
        }

    # Imported here, as in the original, so the agent backend client package
    # is only loaded when a v2 Agent node is actually built.
    from clients.agent_backend import AgentBackendRunEventAdapter, AgentBackendRunRequestBuilder
    from clients.agent_backend.factory import create_agent_backend_run_client

    binding_resolver = WorkflowAgentBindingResolver()
    runtime_request_builder = WorkflowAgentRuntimeRequestBuilder(
        credentials_provider=self._llm_credentials_provider,
        request_builder=AgentBackendRunRequestBuilder(),
    )
    agent_backend_client = create_agent_backend_run_client(
        base_url=dify_config.AGENT_BACKEND_BASE_URL,
        use_fake=dify_config.AGENT_BACKEND_USE_FAKE,
        fake_scenario=dify_config.AGENT_BACKEND_FAKE_SCENARIO,
    )
    event_adapter = AgentBackendRunEventAdapter()
    output_adapter = WorkflowAgentOutputAdapter()
    return {
        "binding_resolver": binding_resolver,
        "runtime_request_builder": runtime_request_builder,
        "agent_backend_client": agent_backend_client,
        "event_adapter": event_adapter,
        "output_adapter": output_adapter,
    }
|
||||
|
||||
def _build_llm_compatible_node_init_kwargs(
|
||||
self,
|
||||
*,
|
||||
|
||||
4
api/core/workflow/nodes/agent_v2/__init__.py
Normal file
4
api/core/workflow/nodes/agent_v2/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .agent_node import DifyAgentNode
|
||||
from .entities import DifyAgentNodeData
|
||||
|
||||
__all__ = ["DifyAgentNode", "DifyAgentNodeData"]
|
||||
281
api/core/workflow/nodes/agent_v2/agent_node.py
Normal file
281
api/core/workflow/nodes/agent_v2/agent_node.py
Normal file
@ -0,0 +1,281 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendError,
|
||||
AgentBackendHTTPError,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunClient,
|
||||
AgentBackendRunEventAdapter,
|
||||
AgentBackendRunFailedInternalEvent,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
AgentBackendStreamError,
|
||||
AgentBackendStreamInternalEvent,
|
||||
AgentBackendTransportError,
|
||||
AgentBackendValidationError,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
from .binding_resolver import WorkflowAgentBindingError, WorkflowAgentBindingResolver
|
||||
from .entities import DifyAgentNodeData
|
||||
from .output_adapter import WorkflowAgentOutputAdapter
|
||||
from .runtime_request_builder import (
|
||||
WorkflowAgentRuntimeBuildContext,
|
||||
WorkflowAgentRuntimeRequestBuilder,
|
||||
WorkflowAgentRuntimeRequestBuildError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
class DifyAgentNode(Node[DifyAgentNodeData]):
|
||||
node_type = BuiltinNodeTypes.AGENT
|
||||
|
||||
def __init__(
    self,
    node_id: str,
    data: DifyAgentNodeData,
    *,
    graph_init_params: GraphInitParams,
    graph_runtime_state: GraphRuntimeState,
    binding_resolver: WorkflowAgentBindingResolver,
    runtime_request_builder: WorkflowAgentRuntimeRequestBuilder,
    agent_backend_client: AgentBackendRunClient,
    event_adapter: AgentBackendRunEventAdapter,
    output_adapter: WorkflowAgentOutputAdapter,
) -> None:
    """Initialise the base node and keep references to the injected collaborators.

    ``node_id``, ``data`` and the two graph arguments are forwarded to the
    base ``Node`` initialiser; the remaining keyword-only arguments are
    stored on private attributes for use by ``_run``.
    """
    super().__init__(
        node_id=node_id,
        data=data,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
    )
    # Collaborators used while preparing the backend request.
    self._binding_resolver = binding_resolver
    self._runtime_request_builder = runtime_request_builder
    # Collaborators used while running and interpreting the backend stream.
    self._agent_backend_client = agent_backend_client
    self._event_adapter = event_adapter
    self._output_adapter = output_adapter
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
    """Return the node version string for this Agent node implementation ("2")."""
    return "2"
|
||||
|
||||
def populate_start_event(self, event) -> None:
    """Record this node's version and agent node kind under ``event.extras["agent_node"]``."""
    agent_node_info = {
        "version": "2",
        "agent_node_kind": self.node_data.agent_node_kind,
    }
    event.extras["agent_node"] = agent_node_info
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
    """Resolve the bound agent, start a backend run, and report its outcome.

    The generator yields a ``StreamCompletedEvent`` describing one of:

    * a setup failure (binding lookup, request building or run creation raised),
    * the success or failure result built from the backend's terminal event,
    * a streaming failure (consuming or adapting the event stream raised),
    * a failure because the stream ended without any terminal event.
    """
    dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
    workflow_id = self.graph_init_params.workflow_id
    workflow_run_id = get_system_text(
        self.graph_runtime_state.variable_pool,
        SystemVariableKey.WORKFLOW_EXECUTION_ID,
    )

    # Placeholder values reported when setup fails before the real ones exist.
    inputs: dict[str, Any] = {}
    process_data: dict[str, Any] = {}
    metadata: dict[str, Any] = {"agent_backend": {"status": "not_started"}}

    try:
        bundle = self._binding_resolver.resolve(
            tenant_id=dify_ctx.tenant_id,
            app_id=dify_ctx.app_id,
            workflow_id=workflow_id,
            node_id=self._node_id,
        )
        build_context = WorkflowAgentRuntimeBuildContext(
            dify_context=dify_ctx,
            workflow_id=workflow_id,
            workflow_run_id=workflow_run_id,
            node_id=self._node_id,
            node_execution_id=self.id,
            variable_pool=self.graph_runtime_state.variable_pool,
            binding=bundle.binding,
            agent=bundle.agent,
            snapshot=bundle.snapshot,
        )
        runtime_request = self._runtime_request_builder.build(build_context)
        inputs = {"agent_backend_request": runtime_request.redacted_request}
        metadata = dict(runtime_request.metadata)
        process_data = {
            "agent_id": bundle.agent.id,
            "agent_config_snapshot_id": bundle.snapshot.id,
            "binding_id": bundle.binding.id,
        }
        create_response = self._agent_backend_client.create_run(runtime_request.request)
        backend_info = dict(metadata.get("agent_backend") or {})
        backend_info["run_id"] = create_response.run_id
        backend_info["status"] = create_response.status
        metadata["agent_backend"] = backend_info
    except Exception as error:
        # Classify in the same precedence as dedicated handlers would:
        # domain errors carry their own code, backend errors are mapped by
        # class, and anything else is a generic node runtime error.
        if isinstance(error, (WorkflowAgentBindingError, WorkflowAgentRuntimeRequestBuildError)):
            setup_error_type = error.error_code
        elif isinstance(error, AgentBackendError):
            setup_error_type = self._agent_backend_error_type(error)
        else:
            setup_error_type = "agent_workflow_node_runtime_error"
        yield self._failure_event(
            inputs=inputs,
            process_data=process_data,
            metadata=metadata,
            error=str(error),
            error_type=setup_error_type,
        )
        return

    # Terminal event types that are reported as failures even without a
    # dedicated "failed" event instance.
    non_success_terminal_types = {
        AgentBackendInternalEventType.RUN_CANCELLED,
        AgentBackendInternalEventType.RUN_PAUSED,
    }

    try:
        public_events = self._agent_backend_client.stream_events(create_response.run_id)
        for stream_event_count, public_event in enumerate(public_events, start=1):
            for internal_event in self._event_adapter.adapt(public_event):
                event_type = internal_event.type
                if event_type == AgentBackendInternalEventType.RUN_STARTED:
                    continue
                if event_type == AgentBackendInternalEventType.STREAM_EVENT:
                    if isinstance(internal_event, AgentBackendStreamInternalEvent):
                        self._record_stream_metadata(metadata, internal_event)
                    continue

                # Any other event refreshes the running count of public
                # stream events seen so far.
                backend_info = dict(metadata.get("agent_backend") or {})
                backend_info["stream_event_count"] = stream_event_count
                metadata["agent_backend"] = backend_info

                if isinstance(internal_event, AgentBackendRunSucceededInternalEvent):
                    success_result = self._output_adapter.build_success_result(
                        event=internal_event,
                        inputs=inputs,
                        process_data=process_data,
                        metadata=metadata,
                    )
                    yield StreamCompletedEvent(node_run_result=success_result)
                    return

                ended_without_success = (
                    isinstance(internal_event, AgentBackendRunFailedInternalEvent)
                    or event_type in non_success_terminal_types
                )
                if ended_without_success:
                    failure_result = self._output_adapter.build_failure_result(
                        event=internal_event,
                        inputs=inputs,
                        process_data=process_data,
                        metadata=metadata,
                    )
                    yield StreamCompletedEvent(node_run_result=failure_result)
                    return
    except Exception as error:
        if isinstance(error, AgentBackendError):
            stream_error_type = self._agent_backend_error_type(error)
        else:
            stream_error_type = "agent_backend_stream_error"
        yield self._failure_event(
            inputs=inputs,
            process_data=process_data,
            metadata=metadata,
            error=str(error),
            error_type=stream_error_type,
        )
        return

    # The stream finished without delivering any terminal event.
    exhausted_result = self._output_adapter.build_stream_exhausted_result(
        inputs=inputs,
        process_data=process_data,
        metadata=metadata,
    )
    yield StreamCompletedEvent(node_run_result=exhausted_result)
|
||||
|
||||
@staticmethod
|
||||
def _failure_event(
    *,
    inputs: dict[str, Any],
    process_data: dict[str, Any],
    metadata: dict[str, Any],
    error: str,
    error_type: str,
) -> StreamCompletedEvent:
    """Build a completion event carrying a FAILED node result with empty outputs.

    ``metadata`` is stored under the ``AGENT_LOG`` execution-metadata key;
    ``error`` and ``error_type`` are passed through unchanged.
    """
    failed_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        inputs=inputs,
        process_data=process_data,
        metadata={WorkflowNodeExecutionMetadataKey.AGENT_LOG: metadata},
        outputs={},
        error=error,
        error_type=error_type,
    )
    return StreamCompletedEvent(node_run_result=failed_result)
|
||||
|
||||
@staticmethod
|
||||
def _agent_backend_error_type(error: AgentBackendError) -> str:
    """Map an Agent backend exception to its workflow ``error_type`` code.

    Classes are tested in a fixed order and the first match wins; an error
    matching none of them is reported as ``"agent_backend_error"``.
    """
    codes_by_class = (
        (AgentBackendValidationError, "agent_backend_validation_error"),
        (AgentBackendHTTPError, "agent_backend_http_error"),
        (AgentBackendStreamError, "agent_backend_stream_error"),
        (AgentBackendTransportError, "agent_backend_transport_error"),
    )
    for error_class, code in codes_by_class:
        if isinstance(error, error_class):
            return code
    return "agent_backend_error"
|
||||
|
||||
@staticmethod
|
||||
def _record_stream_metadata(metadata: dict[str, Any], event: AgentBackendStreamInternalEvent) -> None:
|
||||
agent_backend = dict(metadata.get("agent_backend") or {})
|
||||
agent_backend["last_stream_event_id"] = event.source_event_id
|
||||
if event.event_kind:
|
||||
agent_backend["last_stream_event_kind"] = event.event_kind
|
||||
if isinstance(event.data, Mapping):
|
||||
usage = event.data.get("usage") or event.data.get("model_usage")
|
||||
if isinstance(usage, Mapping):
|
||||
agent_backend["usage"] = dict(usage)
|
||||
metadata["agent_backend"] = agent_backend
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: DifyAgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
del graph_config, node_id, node_data
|
||||
return {}
|
||||
93
api/core/workflow/nodes/agent_v2/binding_resolver.py
Normal file
93
api/core/workflow/nodes/agent_v2/binding_resolver.py
Normal file
@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentStatus, WorkflowAgentNodeBinding
|
||||
|
||||
|
||||
class WorkflowAgentBindingError(Exception):
    """Raised when a workflow Agent binding cannot be resolved.

    ``error_code`` is a short machine-readable code; the human-readable
    message is the exception's sole positional argument.
    """

    error_code: str

    def __init__(self, error_code: str, message: str) -> None:
        super().__init__(message)
        self.error_code = error_code
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class WorkflowAgentBindingBundle:
    """The three records returned together by ``WorkflowAgentBindingResolver.resolve``."""

    # Binding row matched by tenant, app, workflow and node id.
    binding: WorkflowAgentNodeBinding
    # Agent referenced by ``binding.agent_id``.
    agent: Agent
    # Config snapshot referenced by ``binding.current_snapshot_id``.
    snapshot: AgentConfigSnapshot
|
||||
|
||||
|
||||
class WorkflowAgentBindingResolver:
    """Resolve the Agent binding owned by the current workflow id and node id."""

    def resolve(
        self,
        *,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> WorkflowAgentBindingBundle:
        """Load the binding, its agent and its current config snapshot.

        Raises:
            WorkflowAgentBindingError: with code ``agent_binding_not_found``
                when no binding matches; ``agent_not_available`` when the
                binding has no agent or the agent is missing or archived;
                ``agent_config_snapshot_not_found`` when the binding has no
                current snapshot or the snapshot row is missing.
        """
        binding_query = (
            select(WorkflowAgentNodeBinding)
            .where(
                WorkflowAgentNodeBinding.tenant_id == tenant_id,
                WorkflowAgentNodeBinding.app_id == app_id,
                WorkflowAgentNodeBinding.workflow_id == workflow_id,
                WorkflowAgentNodeBinding.node_id == node_id,
            )
            .limit(1)
        )

        with session_factory.create_session() as session:
            binding = session.scalar(binding_query)
            if binding is None:
                raise WorkflowAgentBindingError(
                    "agent_binding_not_found",
                    f"Workflow Agent binding not found for node {node_id}.",
                )
            if binding.agent_id is None:
                raise WorkflowAgentBindingError("agent_not_available", "Workflow Agent binding has no agent.")
            if binding.current_snapshot_id is None:
                raise WorkflowAgentBindingError(
                    "agent_config_snapshot_not_found",
                    "Workflow Agent binding has no current config snapshot.",
                )

            agent_query = (
                select(Agent)
                .where(
                    Agent.tenant_id == tenant_id,
                    Agent.id == binding.agent_id,
                )
                .limit(1)
            )
            agent = session.scalar(agent_query)
            agent_unavailable = agent is None or agent.status == AgentStatus.ARCHIVED
            if agent_unavailable:
                raise WorkflowAgentBindingError(
                    "agent_not_available",
                    f"Agent {binding.agent_id} is not available.",
                )

            snapshot_query = (
                select(AgentConfigSnapshot)
                .where(
                    AgentConfigSnapshot.tenant_id == tenant_id,
                    AgentConfigSnapshot.agent_id == agent.id,
                    AgentConfigSnapshot.id == binding.current_snapshot_id,
                )
                .limit(1)
            )
            snapshot = session.scalar(snapshot_query)
            if snapshot is None:
                raise WorkflowAgentBindingError(
                    "agent_config_snapshot_not_found",
                    f"Agent config snapshot {binding.current_snapshot_id} not found.",
                )

            # Remove the loaded records from the session before handing them
            # to the caller.
            for record in (binding, agent, snapshot):
                session.expunge(record)
            return WorkflowAgentBindingBundle(binding=binding, agent=agent, snapshot=snapshot)
|
||||
17
api/core/workflow/nodes/agent_v2/entities.py
Normal file
17
api/core/workflow/nodes/agent_v2/entities.py
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
|
||||
class DifyAgentNodeData(BaseNodeData):
    """Node data for the v2 Dify Agent node; validation rejects any version other than "2"."""

    type: NodeType = BuiltinNodeTypes.AGENT
    agent_node_kind: Literal["dify_agent"] = "dify_agent"

    @model_validator(mode="after")
    def validate_version(self) -> "DifyAgentNodeData":
        """Return ``self`` when ``version`` is "2"; otherwise raise ``ValueError``."""
        if self.version == "2":
            return self
        raise ValueError("Dify Agent Node v2 requires version='2'")
|
||||
255
api/core/workflow/nodes/agent_v2/output_adapter.py
Normal file
255
api/core/workflow/nodes/agent_v2/output_adapter.py
Normal file
@ -0,0 +1,255 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendInternalEvent,
|
||||
AgentBackendInternalEventType,
|
||||
AgentBackendRunCancelledInternalEvent,
|
||||
AgentBackendRunFailedInternalEvent,
|
||||
AgentBackendRunPausedInternalEvent,
|
||||
AgentBackendRunSucceededInternalEvent,
|
||||
)
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.file import File, FileTransferMethod, FileType
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.node_events import NodeRunResult
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class WorkflowAgentOutputAdapter:
|
||||
"""Convert terminal Agent backend events into workflow node run results."""
|
||||
|
||||
def build_success_result(
    self,
    *,
    event: AgentBackendRunSucceededInternalEvent,
    inputs: dict[str, Any],
    process_data: dict[str, Any],
    metadata: dict[str, Any],
) -> NodeRunResult:
    """Build the SUCCEEDED node result for a successful backend run.

    The metadata is extended with terminal details (status ``"succeeded"``),
    the event output is normalised into the node outputs, and usage found
    in the metadata, when present, is reported as the result's LLM usage.
    """
    terminal_metadata = self._with_terminal_metadata(metadata, event, "succeeded")
    usage = self._usage_from_metadata(terminal_metadata)
    outputs = self._normalize_outputs(event.output)
    node_metadata = self._build_node_metadata(metadata=terminal_metadata, usage=usage)
    llm_usage = usage or LLMUsage.empty_usage()
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=inputs,
        process_data=process_data,
        outputs=outputs,
        metadata=node_metadata,
        llm_usage=llm_usage,
    )
|
||||
|
||||
def build_failure_result(
    self,
    *,
    event: (
        AgentBackendRunFailedInternalEvent
        | AgentBackendRunCancelledInternalEvent
        | AgentBackendRunPausedInternalEvent
    ),
    inputs: dict[str, Any],
    process_data: dict[str, Any],
    metadata: dict[str, Any],
) -> NodeRunResult:
    """Build the FAILED node result for a failed, cancelled or paused backend run.

    The error text, error type and recorded terminal status depend on the
    event class. A paused run is also reported as a failure, with an error
    type marking pause as unsupported.
    """
    # Defaults apply to any event that matches none of the classes below.
    error = "Agent backend run failed."
    error_type = "agent_backend_run_failed"
    terminal_status = "failed"

    if isinstance(event, AgentBackendRunFailedInternalEvent):
        error = event.error
        error_type = event.reason or "agent_backend_run_failed"
        terminal_status = "failed"
    elif isinstance(event, AgentBackendRunCancelledInternalEvent):
        error = event.message or "Agent backend run was cancelled."
        error_type = "agent_backend_run_cancelled"
        terminal_status = "cancelled"
    elif isinstance(event, AgentBackendRunPausedInternalEvent):
        error = event.message or "Agent backend run paused, but workflow Agent Node pause is not supported yet."
        error_type = "agent_backend_paused_unsupported"
        terminal_status = "paused"

    terminal_metadata = self._with_terminal_metadata(metadata, event, terminal_status)
    usage = self._usage_from_metadata(terminal_metadata)
    node_metadata = self._build_node_metadata(metadata=terminal_metadata, usage=usage)
    llm_usage = usage or LLMUsage.empty_usage()
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        inputs=inputs,
        process_data=process_data,
        metadata=node_metadata,
        llm_usage=llm_usage,
        error=error,
        error_type=error_type,
    )
|
||||
|
||||
def build_stream_exhausted_result(
    self,
    *,
    inputs: dict[str, Any],
    process_data: dict[str, Any],
    metadata: dict[str, Any],
) -> NodeRunResult:
    """Build the FAILED node result for a stream that ended without a terminal event."""
    usage = self._usage_from_metadata(metadata)
    node_metadata = self._build_node_metadata(metadata=metadata, usage=usage)
    llm_usage = usage or LLMUsage.empty_usage()
    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        inputs=inputs,
        process_data=process_data,
        metadata=node_metadata,
        llm_usage=llm_usage,
        error="Agent backend stream ended before a terminal event.",
        error_type="agent_backend_stream_error",
    )
|
||||
|
||||
@classmethod
|
||||
def _normalize_outputs(cls, output: Any) -> dict[str, Any]:
|
||||
if isinstance(output, dict):
|
||||
if cls._is_file_payload(output):
|
||||
return {"file": cls._file_segment_from_payload(output)}
|
||||
return {key: cls._normalize_output_value(value) for key, value in output.items()}
|
||||
if isinstance(output, str):
|
||||
return {"text": output}
|
||||
return {"result": output}
|
||||
|
||||
@classmethod
|
||||
def _normalize_output_value(cls, value: Any) -> Any:
    """Recursively convert file-like payloads inside an output value.

    File objects and file segments pass through unchanged. Mappings that
    look like file payloads become file segments; other mappings and lists
    are normalised element by element. A non-empty list made up only of
    file payload mappings becomes a single array-file segment. Every other
    value is returned as is.
    """
    if isinstance(value, (File, FileSegment, ArrayFileSegment)):
        return value

    if isinstance(value, Mapping):
        if not cls._is_file_payload(value):
            return {key: cls._normalize_output_value(item) for key, item in value.items()}
        return cls._file_segment_from_payload(value)

    if not isinstance(value, list):
        return value

    only_file_payloads = bool(value) and all(
        isinstance(item, Mapping) and cls._is_file_payload(item) for item in value
    )
    if only_file_payloads:
        files = [cls._file_from_payload(item) for item in value]
        return ArrayFileSegment(value=files)
    return [cls._normalize_output_value(item) for item in value]
|
||||
|
||||
@staticmethod
|
||||
def _is_file_payload(value: Mapping[str, Any]) -> bool:
|
||||
return any(value.get(key) for key in ("file_id", "upload_file_id", "tool_file_id", "url", "remote_url")) and (
|
||||
"filename" in value or "mime_type" in value or "url" in value or "remote_url" in value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _file_segment_from_payload(cls, value: Mapping[str, Any]) -> FileSegment:
    """Wrap the ``File`` built from a file payload in a ``FileSegment``."""
    file = cls._file_from_payload(value)
    return FileSegment(value=file)
|
||||
|
||||
@classmethod
|
||||
def _file_from_payload(cls, value: Mapping[str, Any]) -> File:
    """Build a ``File`` from a loosely structured file payload.

    Several alternative keys are accepted for each field (for example
    ``remote_url``/``url`` and ``filename``/``name``). The transfer method is
    chosen by precedence: a tool file id first, then a remote URL, otherwise
    a local file identified by the upload file id. A ``size`` that is not an
    ``int`` is replaced by ``-1``.
    """
    remote_url = cls._string_value(value.get("remote_url") or value.get("url"))
    upload_file_id = cls._string_value(value.get("upload_file_id") or value.get("file_id"))
    tool_file_id = cls._string_value(value.get("tool_file_id"))
    filename = cls._string_value(value.get("filename") or value.get("name"))
    mime_type = cls._string_value(value.get("mime_type") or value.get("mimetype"))
    extension = cls._extension_from_payload(value, filename)
    file_type = cls._file_type_from_payload(value, mime_type)

    raw_size = value.get("size")
    size = raw_size if isinstance(raw_size, int) else -1

    # Only the remote-URL transfer method carries a URL on the resulting File.
    effective_remote_url = None
    if tool_file_id:
        transfer_method = FileTransferMethod.TOOL_FILE
        related_id = tool_file_id
    elif remote_url:
        transfer_method = FileTransferMethod.REMOTE_URL
        related_id = None
        effective_remote_url = remote_url
    else:
        transfer_method = FileTransferMethod.LOCAL_FILE
        related_id = upload_file_id

    return File(
        type=file_type,
        transfer_method=transfer_method,
        remote_url=effective_remote_url,
        related_id=related_id,
        filename=filename,
        extension=extension,
        mime_type=mime_type,
        size=size,
    )
|
||||
|
||||
@staticmethod
|
||||
def _string_value(value: Any) -> str | None:
|
||||
return value if isinstance(value, str) and value else None
|
||||
|
||||
@classmethod
|
||||
def _extension_from_payload(cls, value: Mapping[str, Any], filename: str | None) -> str | None:
|
||||
extension = cls._string_value(value.get("extension"))
|
||||
if extension:
|
||||
return extension if extension.startswith(".") else f".{extension}"
|
||||
if filename and "." in filename:
|
||||
return f".{filename.rsplit('.', 1)[1]}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _file_type_from_payload(value: Mapping[str, Any], mime_type: str | None) -> FileType:
    """Choose the ``FileType`` for a file payload.

    An explicit ``type``/``file_type`` string wins when it is a valid
    ``FileType`` value. Otherwise the MIME type prefix decides between
    image, audio and video, and any other MIME type is a document. With no
    usable information the type is ``CUSTOM``.
    """
    declared = value.get("type") or value.get("file_type")
    if isinstance(declared, str):
        try:
            return FileType(declared)
        except ValueError:
            # Unknown explicit type: fall back to MIME-based detection.
            pass

    if not mime_type:
        return FileType.CUSTOM

    types_by_prefix = (
        ("image/", FileType.IMAGE),
        ("audio/", FileType.AUDIO),
        ("video/", FileType.VIDEO),
    )
    for prefix, detected in types_by_prefix:
        if mime_type.startswith(prefix):
            return detected
    return FileType.DOCUMENT
|
||||
|
||||
@staticmethod
|
||||
def _usage_from_metadata(metadata: Mapping[str, Any]) -> LLMUsage | None:
|
||||
agent_backend = metadata.get("agent_backend")
|
||||
if not isinstance(agent_backend, Mapping):
|
||||
return None
|
||||
usage = agent_backend.get("usage")
|
||||
if not isinstance(usage, Mapping):
|
||||
return None
|
||||
try:
|
||||
return LLMUsage.from_metadata(usage)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_node_metadata(
    *,
    metadata: dict[str, Any],
    usage: LLMUsage | None,
) -> dict[WorkflowNodeExecutionMetadataKey, Any]:
    """Assemble the node execution metadata.

    ``metadata`` is stored under ``AGENT_LOG``; when ``usage`` is given, its
    total tokens, total price and currency are added as well.
    """
    node_metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
        WorkflowNodeExecutionMetadataKey.AGENT_LOG: metadata,
    }
    if usage is None:
        return node_metadata
    node_metadata.update(
        {
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
            WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
        }
    )
    return node_metadata
|
||||
|
||||
@staticmethod
|
||||
def _with_terminal_metadata(
    metadata: dict[str, Any],
    event: AgentBackendInternalEvent,
    terminal_status: str,
) -> dict[str, Any]:
    """Return a copy of ``metadata`` enriched with details of the terminal event.

    The ``agent_backend`` section gains the run id, the terminal event id
    and ``terminal_status``. For succeeded and paused events that carry a
    session snapshot, the snapshot's layer count is recorded too. The
    top-level ``terminal_event_type`` holds the event type's value. The
    input dictionary itself is not modified.
    """
    backend_info = dict(metadata.get("agent_backend") or {})
    backend_info["run_id"] = event.run_id
    backend_info["terminal_event_id"] = event.source_event_id
    backend_info["status"] = terminal_status

    # Only succeeded and paused events are read for a session snapshot.
    if isinstance(event, (AgentBackendRunSucceededInternalEvent, AgentBackendRunPausedInternalEvent)):
        snapshot = event.session_snapshot
        if snapshot is not None:
            backend_info["session_snapshot"] = {"layer_count": len(snapshot.layers)}

    enriched = dict(metadata)
    enriched["agent_backend"] = backend_info
    enriched["terminal_event_type"] = AgentBackendInternalEventType(event.type).value
    return enriched
|
||||
55
api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py
Normal file
55
api/core/workflow/nodes/agent_v2/runtime_feature_manifest.py
Normal file
@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from models.agent_config_entities import AgentSoulConfig
|
||||
|
||||
# Features listed as "supported" in the runtime manifest.
SUPPORTED_AGENT_BACKEND_FEATURES = frozenset(
    (
        "system_prompt",
        "workflow_prompt",
        "workflow_context",
        "model",
        "structured_output",
    )
)

# Agent Soul sections that are persisted but reported as "reserved_not_executed".
RESERVED_AGENT_BACKEND_FEATURES = frozenset(
    (
        "skills_files",
        "tools",
        "knowledge",
        "human",
        "env",
        "sandbox",
        "memory",
    )
)


def _reserved_section_has_content(value: Any) -> bool:
    """Tell whether a dumped Agent Soul section carries any configured value.

    A dict counts as configured only when at least one of its values is truthy;
    any other value is judged by its own truthiness.
    """
    if isinstance(value, dict):
        return any(value.values())
    return bool(value)


def build_runtime_feature_manifest(agent_soul: AgentSoulConfig) -> dict[str, Any]:
    """Describe PRD capabilities that are persisted but not executed in phase 3.

    Returns a dict with the sorted supported and reserved feature names, a
    per-feature ``reserved_status`` map, and one warning entry for every
    reserved section that has content in ``agent_soul``.
    """
    dumped_soul = agent_soul.model_dump(mode="json")
    reserved_names = sorted(RESERVED_AGENT_BACKEND_FEATURES)

    runtime_warnings: list[dict[str, str]] = [
        {
            "section": f"agent_soul.{name}",
            "code": "agent_backend_layer_not_available",
            "message": f"{name} is saved in Agent Soul but is not executed by Agent backend in phase 3.",
        }
        for name in reserved_names
        if _reserved_section_has_content(dumped_soul.get(name))
    ]

    return {
        "supported": sorted(SUPPORTED_AGENT_BACKEND_FEATURES),
        "reserved": reserved_names,
        "reserved_status": {name: "reserved_not_executed" for name in reserved_names},
        "unsupported_runtime_warnings": runtime_warnings,
    }
|
||||
288
api/core/workflow/nodes/agent_v2/runtime_request_builder.py
Normal file
288
api/core/workflow/nodes/agent_v2/runtime_request_builder.py
Normal file
@ -0,0 +1,288 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Protocol, cast
|
||||
|
||||
from dify_agent.protocol import CreateRunRequest, ExecutionContext
|
||||
|
||||
from clients.agent_backend import (
|
||||
AgentBackendModelConfig,
|
||||
AgentBackendOutputConfig,
|
||||
AgentBackendRunRequestBuilder,
|
||||
AgentBackendWorkflowNodeRunInput,
|
||||
redact_for_agent_backend_log,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
from graphon.variables.segments import Segment
|
||||
from models.agent import Agent, AgentConfigSnapshot, WorkflowAgentNodeBinding
|
||||
from models.agent_config_entities import (
|
||||
AgentSoulConfig,
|
||||
DeclaredOutputConfig,
|
||||
DeclaredOutputType,
|
||||
WorkflowNodeJobConfig,
|
||||
)
|
||||
|
||||
from .runtime_feature_manifest import build_runtime_feature_manifest
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeRequestBuildError(ValueError):
    """Signal that workflow state cannot be turned into a valid Agent backend run request.

    Attributes:
        error_code: Short machine-readable identifier of the failure.
    """

    def __init__(self, error_code: str, message: str) -> None:
        # The human-readable message becomes the exception's only positional arg.
        super().__init__(message)
        self.error_code = error_code
|
||||
|
||||
|
||||
class VariablePoolReader(Protocol):
    """Structural type for the variable-pool reads this module depends on."""

    # Look up one variable by selector; callers in this module treat ``None``
    # as "the referenced variable cannot be resolved".
    def get(self, selector: Sequence[str], /) -> Segment | None: ...

    # NOTE(review): not called directly in this module; presumably required by
    # ``get_system_text``, which receives the pool -- confirm against that helper.
    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ...
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
    """Structural type for the object that supplies model credentials to the request builder."""

    # Called with the Agent Soul's model provider name and model name; the
    # returned mapping is normalized to scalar values before it is sent on.
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class WorkflowAgentRuntimeBuildContext:
    """Immutable bundle of workflow state needed to build one Agent backend run request."""

    # Supplies tenant_id, app_id, user_id and invoke_from for the request.
    dify_context: DifyRunContext
    workflow_id: str
    # May be None; the idempotency key then falls back to node_execution_id alone.
    workflow_run_id: str | None
    node_id: str
    node_execution_id: str
    # Source of the system query/conversation id and of previous node outputs.
    variable_pool: VariablePoolReader
    # Provides the node job config and the binding id recorded in metadata.
    binding: WorkflowAgentNodeBinding
    agent: Agent
    # Provides the Agent Soul config and the config version id.
    snapshot: AgentConfigSnapshot
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class WorkflowAgentRuntimeRequest:
    """Result of a build: the backend request plus the validated inputs it was derived from."""

    # The request to submit to the Agent backend.
    request: CreateRunRequest
    # Output of ``redact_for_agent_backend_log`` applied to ``request``.
    redacted_request: dict[str, Any]
    # Validated Agent Soul config parsed from the snapshot.
    agent_soul: AgentSoulConfig
    # Validated node job config parsed from the binding.
    node_job: WorkflowNodeJobConfig
    # The same metadata mapping that was attached to the request.
    metadata: dict[str, Any]
|
||||
|
||||
|
||||
class WorkflowAgentRuntimeRequestBuilder:
|
||||
"""Build public Dify Agent run requests from workflow Agent v2 runtime state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
credentials_provider: CredentialsProvider,
|
||||
request_builder: AgentBackendRunRequestBuilder | None = None,
|
||||
) -> None:
|
||||
self._credentials_provider = credentials_provider
|
||||
self._request_builder = request_builder or AgentBackendRunRequestBuilder()
|
||||
|
||||
def build(self, context: WorkflowAgentRuntimeBuildContext) -> WorkflowAgentRuntimeRequest:
|
||||
agent_soul = AgentSoulConfig.model_validate(context.snapshot.config_snapshot_dict)
|
||||
node_job = WorkflowNodeJobConfig.model_validate(context.binding.node_job_config_dict)
|
||||
if agent_soul.model is None:
|
||||
raise WorkflowAgentRuntimeRequestBuildError(
|
||||
"agent_model_not_configured",
|
||||
"Workflow Agent node requires Agent Soul model config.",
|
||||
)
|
||||
|
||||
metadata = self._build_metadata(context, agent_soul, node_job)
|
||||
workflow_context_prompt = self._build_workflow_context_prompt(context, node_job)
|
||||
workflow_job_prompt = node_job.workflow_prompt.strip() or "Run this workflow Agent Node for the current run."
|
||||
user_prompt = workflow_context_prompt.strip() or "Use the current workflow context."
|
||||
credentials = self._credentials_provider.fetch(agent_soul.model.model_provider, agent_soul.model.model)
|
||||
|
||||
request = self._request_builder.build_for_workflow_node(
|
||||
AgentBackendWorkflowNodeRunInput(
|
||||
model=AgentBackendModelConfig(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
plugin_id=agent_soul.model.plugin_id,
|
||||
model_provider=agent_soul.model.model_provider,
|
||||
model=agent_soul.model.model,
|
||||
user_id=context.dify_context.user_id,
|
||||
credentials=self._normalize_credentials(credentials),
|
||||
model_settings=cast(dict[str, Any], agent_soul.model.model_settings),
|
||||
),
|
||||
execution_context=ExecutionContext(
|
||||
tenant_id=context.dify_context.tenant_id,
|
||||
app_id=context.dify_context.app_id,
|
||||
workflow_id=context.workflow_id,
|
||||
workflow_run_id=context.workflow_run_id,
|
||||
node_id=context.node_id,
|
||||
node_execution_id=context.node_execution_id,
|
||||
conversation_id=get_system_text(context.variable_pool, SystemVariableKey.CONVERSATION_ID),
|
||||
agent_id=context.agent.id,
|
||||
agent_config_version_id=context.snapshot.id,
|
||||
invoke_from=self._agent_backend_invoke_from(context.dify_context.invoke_from),
|
||||
),
|
||||
agent_soul_prompt=agent_soul.prompt.system_prompt or None,
|
||||
workflow_node_job_prompt=workflow_job_prompt,
|
||||
user_prompt=user_prompt,
|
||||
output=self._build_output_config(node_job.declared_outputs),
|
||||
idempotency_key=self._idempotency_key(context),
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
redacted = cast(dict[str, Any], redact_for_agent_backend_log(request))
|
||||
return WorkflowAgentRuntimeRequest(
|
||||
request=request,
|
||||
redacted_request=redacted,
|
||||
agent_soul=agent_soul,
|
||||
node_job=node_job,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _agent_backend_invoke_from(invoke_from: InvokeFrom) -> Literal["workflow_run", "single_step"]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.VALIDATION}:
|
||||
return "single_step"
|
||||
return "workflow_run"
|
||||
|
||||
@staticmethod
|
||||
def _idempotency_key(context: WorkflowAgentRuntimeBuildContext) -> str:
|
||||
if context.workflow_run_id:
|
||||
return f"{context.workflow_run_id}:{context.node_execution_id}"
|
||||
return context.node_execution_id
|
||||
|
||||
@staticmethod
|
||||
def _build_metadata(
|
||||
context: WorkflowAgentRuntimeBuildContext,
|
||||
agent_soul: AgentSoulConfig,
|
||||
node_job: WorkflowNodeJobConfig,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"tenant_id": context.dify_context.tenant_id,
|
||||
"app_id": context.dify_context.app_id,
|
||||
"workflow_id": context.workflow_id,
|
||||
"workflow_run_id": context.workflow_run_id,
|
||||
"node_id": context.node_id,
|
||||
"node_execution_id": context.node_execution_id,
|
||||
"agent_id": context.agent.id,
|
||||
"agent_config_snapshot_id": context.snapshot.id,
|
||||
"binding_id": context.binding.id,
|
||||
"workflow_node_job_mode": node_job.mode.value,
|
||||
"runtime_support": build_runtime_feature_manifest(agent_soul),
|
||||
}
|
||||
|
||||
def _build_workflow_context_prompt(
|
||||
self,
|
||||
context: WorkflowAgentRuntimeBuildContext,
|
||||
node_job: WorkflowNodeJobConfig,
|
||||
) -> str:
|
||||
lines = ["Workflow context loaded for this run:"]
|
||||
query = get_system_text(context.variable_pool, SystemVariableKey.QUERY)
|
||||
if query:
|
||||
lines.append(f"- User query: {query}")
|
||||
|
||||
resolved_outputs = self._resolve_previous_node_outputs(
|
||||
context.variable_pool,
|
||||
node_job.previous_node_output_refs,
|
||||
)
|
||||
if resolved_outputs:
|
||||
lines.append("- Previous node outputs:")
|
||||
for item in resolved_outputs:
|
||||
lines.append(f" - {item['label']}: {item['value']}")
|
||||
|
||||
lines.append("The above workflow context is run-specific. Do not treat it as Agent Soul or persistent memory.")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _resolve_previous_node_outputs(
|
||||
self,
|
||||
variable_pool: VariablePoolReader,
|
||||
refs: Sequence[Mapping[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
resolved: list[dict[str, Any]] = []
|
||||
for ref in refs:
|
||||
selector = self._selector_from_ref(ref)
|
||||
if not selector:
|
||||
raise WorkflowAgentRuntimeRequestBuildError(
|
||||
"invalid_previous_node_output_ref",
|
||||
"Workflow Agent node has invalid previous node output ref.",
|
||||
)
|
||||
segment = variable_pool.get(selector)
|
||||
if segment is None:
|
||||
raise WorkflowAgentRuntimeRequestBuildError(
|
||||
"missing_previous_node_output",
|
||||
f"Workflow Agent node cannot resolve previous node output {'.'.join(selector)}.",
|
||||
)
|
||||
value = getattr(segment, "value", None)
|
||||
resolved.append(
|
||||
{
|
||||
"label": ".".join(selector),
|
||||
"value": self._summarize_value(value),
|
||||
}
|
||||
)
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _selector_from_ref(ref: Mapping[str, Any]) -> list[str] | None:
|
||||
for key in ("selector", "variable_selector", "value_selector"):
|
||||
value = ref.get(key)
|
||||
if isinstance(value, list) and all(isinstance(item, str) for item in value):
|
||||
return value
|
||||
node_id = ref.get("node_id")
|
||||
output_name = ref.get("output") or ref.get("name") or ref.get("variable") or ref.get("key")
|
||||
if isinstance(node_id, str) and isinstance(output_name, str):
|
||||
return [node_id, output_name]
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _summarize_value(value: Any) -> str:
|
||||
text = str(value)
|
||||
if len(text) > 2000:
|
||||
return text[:2000] + "...[truncated]"
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _build_output_config(declared_outputs: Sequence[DeclaredOutputConfig]) -> AgentBackendOutputConfig | None:
|
||||
if not declared_outputs:
|
||||
return None
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for output in declared_outputs:
|
||||
properties[output.name] = WorkflowAgentRuntimeRequestBuilder._schema_for_declared_output(output)
|
||||
if output.required:
|
||||
required.append(output.name)
|
||||
schema: dict[str, Any] = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
return AgentBackendOutputConfig(json_schema=schema)
|
||||
|
||||
@staticmethod
|
||||
def _schema_for_declared_output(output: DeclaredOutputConfig) -> dict[str, Any]:
|
||||
match output.type:
|
||||
case DeclaredOutputType.STRING:
|
||||
schema: dict[str, Any] = {"type": "string"}
|
||||
case DeclaredOutputType.NUMBER:
|
||||
schema = {"type": "number"}
|
||||
case DeclaredOutputType.BOOLEAN:
|
||||
schema = {"type": "boolean"}
|
||||
case DeclaredOutputType.OBJECT:
|
||||
schema = {"type": "object"}
|
||||
case DeclaredOutputType.ARRAY:
|
||||
schema = {"type": "array"}
|
||||
case DeclaredOutputType.FILE:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {"type": "string"},
|
||||
"filename": {"type": "string"},
|
||||
"mime_type": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
},
|
||||
}
|
||||
if output.description:
|
||||
schema["description"] = output.description
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def _normalize_credentials(credentials: Mapping[str, Any]) -> dict[str, str | int | float | bool | None]:
|
||||
normalized: dict[str, str | int | float | bool | None] = {}
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, str | int | float | bool) or value is None:
|
||||
normalized[key] = value
|
||||
else:
|
||||
normalized[key] = str(value)
|
||||
return normalized
|
||||
388
api/core/workflow/nodes/agent_v2/validators.py
Normal file
388
api/core/workflow/nodes/agent_v2/validators.py
Normal file
@ -0,0 +1,388 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterator, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from models.agent import Agent, AgentConfigSnapshot, AgentStatus, WorkflowAgentNodeBinding
|
||||
from models.agent_config_entities import AgentSoulConfig, WorkflowNodeJobConfig
|
||||
from models.model import UploadFile
|
||||
from models.workflow import Workflow
|
||||
|
||||
from .entities import DifyAgentNodeData
|
||||
|
||||
|
||||
class WorkflowAgentNodeValidationError(ValueError):
    """Raised when a Workflow Agent v2 node cannot be executed or published.

    Subclasses ``ValueError``; the message names the offending node id and
    the violated rule.
    """
|
||||
|
||||
|
||||
class WorkflowAgentNodeValidator:
    """Validate Agent v2 workflow nodes against graph topology and persisted bindings.

    Every check raises ``WorkflowAgentNodeValidationError`` on failure and
    returns ``None`` on success.
    """

    # Keys that must not appear anywhere inside a node job's ``metadata``
    # mappings, because they would override fields owned by the Agent Soul.
    _LOCKED_AGENT_SOUL_KEYS = frozenset(
        {
            "agent_soul",
            "soul",
            "prompt",
            "system_prompt",
            "skills_files",
            "skills",
            "files",
            "tools",
            "dify_tools",
            "cli_tools",
            "knowledge",
            "env",
            "environment",
            "sandbox",
            "sandbox_provider",
            "memory",
            "memory_strategy",
            "model",
            "app_features",
            "app_variables",
            "misc_legacy",
        }
    )
    _SUPPORTED_HUMAN_CONTACT_CHANNELS = frozenset({"email", "slack", "web_app", "webapp", "chat"})

    @classmethod
    def validate_draft_workflow(cls, *, session: Session, workflow: Workflow) -> None:
        """Validate a draft workflow; Agent v2 nodes without a binding are skipped."""
        cls._validate_workflow(session=session, workflow=workflow, require_binding=False)

    @classmethod
    def validate_published_workflow(cls, *, session: Session, workflow: Workflow) -> None:
        """Validate a workflow for publishing; every Agent v2 node must have a binding."""
        cls._validate_workflow(session=session, workflow=workflow, require_binding=True)

    @classmethod
    def _validate_workflow(cls, *, session: Session, workflow: Workflow, require_binding: bool) -> None:
        """Validate schema and binding of every Agent v2 node in the workflow graph.

        Raises:
            WorkflowAgentNodeValidationError: On an invalid node schema, a
                missing binding when ``require_binding`` is set, or any
                failure reported by ``validate_binding``.
        """
        graph = workflow.graph_dict
        topology = _WorkflowGraphTopology.from_graph(graph)
        for node_id, node_data in cls.iter_agent_v2_nodes(graph):
            cls._validate_node_schema(node_id=node_id, node_data=node_data)
            binding = cls._find_binding(
                session=session,
                tenant_id=workflow.tenant_id,
                app_id=workflow.app_id,
                workflow_id=workflow.id,
                node_id=node_id,
            )
            if binding is None:
                if require_binding:
                    raise WorkflowAgentNodeValidationError(
                        f"Workflow Agent node {node_id} requires a binding before publishing."
                    )
                continue
            cls.validate_binding(session=session, binding=binding, topology=topology)

    @classmethod
    def validate_binding(
        cls,
        *,
        session: Session,
        binding: WorkflowAgentNodeBinding,
        topology: _WorkflowGraphTopology | None = None,
    ) -> None:
        """Check that a binding points at a usable agent, snapshot and node job.

        The agent and snapshot are looked up within the binding's tenant.
        Topology checks on previous-node refs run only when ``topology`` is
        given.

        Raises:
            WorkflowAgentNodeValidationError: When the agent or snapshot id is
                missing, the agent is absent or archived, the snapshot is
                absent, the Agent Soul has no model config, or the node job
                is invalid.
        """
        if binding.agent_id is None:
            raise WorkflowAgentNodeValidationError(f"Workflow Agent node {binding.node_id} is missing agent binding.")
        if binding.current_snapshot_id is None:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} is missing config snapshot binding."
            )

        agent = session.scalar(
            select(Agent)
            .where(
                Agent.tenant_id == binding.tenant_id,
                Agent.id == binding.agent_id,
            )
            .limit(1)
        )
        if agent is None or agent.status == AgentStatus.ARCHIVED:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references an unavailable agent."
            )

        snapshot = session.scalar(
            select(AgentConfigSnapshot)
            .where(
                AgentConfigSnapshot.tenant_id == binding.tenant_id,
                AgentConfigSnapshot.agent_id == agent.id,
                AgentConfigSnapshot.id == binding.current_snapshot_id,
            )
            .limit(1)
        )
        if snapshot is None:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references a missing config snapshot."
            )

        agent_soul = AgentSoulConfig.model_validate(snapshot.config_snapshot_dict)
        if agent_soul.model is None:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} requires Agent Soul model config."
            )
        node_job = WorkflowNodeJobConfig.model_validate(binding.node_job_config_dict)
        cls.validate_node_job(session=session, binding=binding, node_job=node_job, topology=topology)

    @classmethod
    def validate_node_job(
        cls,
        *,
        session: Session,
        binding: WorkflowAgentNodeBinding,
        node_job: WorkflowNodeJobConfig,
        topology: _WorkflowGraphTopology | None = None,
    ) -> None:
        """Validate a node job config: locked keys, outputs, upstream refs, human and file refs.

        Raises:
            WorkflowAgentNodeValidationError: On the first violated rule.
        """
        cls._validate_locked_agent_soul_not_overridden(binding=binding, node_job=node_job)

        output_names: set[str] = set()
        for output in node_job.declared_outputs:
            if output.name in output_names:
                raise WorkflowAgentNodeValidationError(
                    f"Workflow Agent node {binding.node_id} has duplicate output name {output.name}."
                )
            output_names.add(output.name)
            for check in output.checks:
                if check.benchmark_file_ref is not None:
                    cls._validate_file_ref(
                        session=session,
                        binding=binding,
                        file_ref=check.benchmark_file_ref,
                        ref_context=f"output {output.name} benchmark file",
                    )

        for ref in node_job.previous_node_output_refs:
            selector = cls.selector_from_ref(ref)
            if selector is None:
                raise WorkflowAgentNodeValidationError(
                    f"Workflow Agent node {binding.node_id} has invalid previous node output ref."
                )
            # Without a graph only the ref's shape can be checked.
            if topology is None:
                continue
            if len(selector) < 2:
                raise WorkflowAgentNodeValidationError(
                    f"Workflow Agent node {binding.node_id} has incomplete previous node output ref."
                )
            source_node_id = selector[0]
            if not topology.has_node(source_node_id):
                raise WorkflowAgentNodeValidationError(
                    f"Workflow Agent node {binding.node_id} references missing previous node {source_node_id}."
                )
            if not topology.is_upstream(source_node_id=source_node_id, target_node_id=binding.node_id):
                raise WorkflowAgentNodeValidationError(
                    f"Workflow Agent node {binding.node_id} references non-upstream previous node {source_node_id}."
                )

        for human_ref in node_job.human_contacts:
            cls._validate_human_ref(binding=binding, human_ref=human_ref)

        # Only mapping entries of a list-valued ``file_refs`` are validated; other shapes are ignored.
        file_refs = node_job.metadata.get("file_refs")
        if isinstance(file_refs, list):
            for file_ref in file_refs:
                if isinstance(file_ref, Mapping):
                    cls._validate_file_ref(
                        session=session,
                        binding=binding,
                        file_ref=file_ref,
                        ref_context="metadata file ref",
                    )

    @staticmethod
    def iter_agent_v2_nodes(graph_dict: Mapping[str, Any]) -> Iterator[tuple[str, Mapping[str, Any]]]:
        """Yield ``(node_id, node_data)`` for each agent node whose ``version`` stringifies to ``"2"``.

        Malformed entries (non-list ``nodes``, non-mapping node, non-string id,
        non-mapping data) are skipped silently.
        """
        nodes = graph_dict.get("nodes")
        if not isinstance(nodes, list):
            return
        for node in nodes:
            if not isinstance(node, Mapping):
                continue
            node_id = node.get("id")
            node_data = node.get("data")
            if not isinstance(node_id, str) or not isinstance(node_data, Mapping):
                continue
            if node_data.get("type") == BuiltinNodeTypes.AGENT and str(node_data.get("version")) == "2":
                yield node_id, node_data

    @staticmethod
    def selector_from_ref(ref: Mapping[str, Any]) -> list[str] | None:
        """Extract a variable selector from a previous-node output ref.

        An explicit all-string list under ``selector``, ``variable_selector``
        or ``value_selector`` wins (first match in that order). Otherwise a
        ``[node_id, output_name]`` pair is built from ``node_id`` plus the
        first truthy of ``output``/``name``/``variable``/``key``. Returns
        ``None`` when neither form is present.
        """
        for key in ("selector", "variable_selector", "value_selector"):
            value = ref.get(key)
            if isinstance(value, list) and all(isinstance(item, str) for item in value):
                return value
        node_id = ref.get("node_id")
        output_name = ref.get("output") or ref.get("name") or ref.get("variable") or ref.get("key")
        if isinstance(node_id, str) and isinstance(output_name, str):
            return [node_id, output_name]
        return None

    @staticmethod
    def _validate_node_schema(*, node_id: str, node_data: Mapping[str, Any]) -> None:
        """Validate raw node data against ``DifyAgentNodeData``, re-raising ``ValueError`` as a validation error."""
        try:
            DifyAgentNodeData.model_validate(node_data)
        except ValueError as exc:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {node_id} has invalid Agent v2 node schema: {exc}"
            ) from exc

    @classmethod
    def _validate_locked_agent_soul_not_overridden(
        cls,
        *,
        binding: WorkflowAgentNodeBinding,
        node_job: WorkflowNodeJobConfig,
    ) -> None:
        """Reject node job metadata that contains any locked Agent Soul key."""
        forbidden_paths = cls._find_locked_agent_soul_paths(node_job.metadata)
        if forbidden_paths:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} cannot override locked Agent Soul fields: "
                f"{', '.join(sorted(forbidden_paths))}."
            )

    @classmethod
    def _find_locked_agent_soul_paths(cls, value: Any, *, path: str = "metadata") -> set[str]:
        """Return dotted paths of all locked keys found in nested mappings under ``value``.

        Only mappings are descended into; lists and scalars contribute nothing.
        """
        if not isinstance(value, Mapping):
            return set()
        forbidden: set[str] = set()
        for key, item in value.items():
            key_text = str(key)
            if key_text in cls._LOCKED_AGENT_SOUL_KEYS:
                forbidden.add(f"{path}.{key_text}")
            forbidden.update(cls._find_locked_agent_soul_paths(item, path=f"{path}.{key_text}"))
        return forbidden

    @classmethod
    def _validate_human_ref(
        cls,
        *,
        binding: WorkflowAgentNodeBinding,
        human_ref: Mapping[str, Any],
    ) -> None:
        """Validate one human contact ref: id present, tenant in scope, channel supported.

        Raises:
            WorkflowAgentNodeValidationError: When the contact id is missing or
                not a non-empty string, an explicit ``tenant_id`` differs from
                the binding's, or a given channel is not a supported one.
        """
        contact_id = human_ref.get("contact_id") or human_ref.get("human_id") or human_ref.get("id")
        if not isinstance(contact_id, str) or not contact_id:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} has invalid human contact ref."
            )

        tenant_id = human_ref.get("tenant_id")
        if tenant_id is not None and tenant_id != binding.tenant_id:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references out-of-scope human contact {contact_id}."
            )

        channel = human_ref.get("channel") or human_ref.get("method") or human_ref.get("contact_method")
        # The str check comes first: an unhashable channel (list/dict from JSON) would make
        # the frozenset membership test raise TypeError instead of a validation error.
        if channel is not None and (
            not isinstance(channel, str) or channel not in cls._SUPPORTED_HUMAN_CONTACT_CHANNELS
        ):
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references unsupported human contact channel {channel}."
            )

    @staticmethod
    def _validate_file_ref(
        *,
        session: Session,
        binding: WorkflowAgentNodeBinding,
        file_ref: Mapping[str, Any],
        ref_context: str,
    ) -> None:
        """Validate one file ref against the binding's tenant.

        A ref with no file id but with ``url``/``remote_url`` is accepted
        without a lookup. Otherwise the id must be a non-empty string naming
        an ``UploadFile`` row of the binding's tenant.

        Raises:
            WorkflowAgentNodeValidationError: On a foreign ``tenant_id``, a
                missing or invalid id, or an upload file that is not found.
        """
        tenant_id = file_ref.get("tenant_id")
        if tenant_id is not None and tenant_id != binding.tenant_id:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references out-of-scope {ref_context}."
            )

        upload_file_id = (
            file_ref.get("upload_file_id") or file_ref.get("file_id") or file_ref.get("id") or file_ref.get("reference")
        )
        if upload_file_id is None and (file_ref.get("url") or file_ref.get("remote_url")):
            return
        if not isinstance(upload_file_id, str) or not upload_file_id:
            raise WorkflowAgentNodeValidationError(f"Workflow Agent node {binding.node_id} has invalid {ref_context}.")

        upload_file = session.scalar(
            select(UploadFile)
            .where(
                UploadFile.tenant_id == binding.tenant_id,
                UploadFile.id == upload_file_id,
            )
            .limit(1)
        )
        if upload_file is None:
            raise WorkflowAgentNodeValidationError(
                f"Workflow Agent node {binding.node_id} references missing or out-of-scope {ref_context}."
            )

    @staticmethod
    def _find_binding(
        *,
        session: Session,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> WorkflowAgentNodeBinding | None:
        """Fetch the binding for one node of one workflow, or ``None`` when none exists."""
        return session.scalar(
            select(WorkflowAgentNodeBinding)
            .where(
                WorkflowAgentNodeBinding.tenant_id == tenant_id,
                WorkflowAgentNodeBinding.app_id == app_id,
                WorkflowAgentNodeBinding.workflow_id == workflow_id,
                WorkflowAgentNodeBinding.node_id == node_id,
            )
            .limit(1)
        )
|
||||
|
||||
|
||||
class _WorkflowGraphTopology:
|
||||
def __init__(self, *, node_ids: set[str], incoming: Mapping[str, Sequence[str]]) -> None:
|
||||
self._node_ids = node_ids
|
||||
self._incoming = incoming
|
||||
|
||||
@classmethod
|
||||
def from_graph(cls, graph: Mapping[str, Any]) -> _WorkflowGraphTopology:
|
||||
node_ids = cls._node_ids_from_graph(graph)
|
||||
incoming: dict[str, list[str]] = defaultdict(list)
|
||||
edges = graph.get("edges")
|
||||
if isinstance(edges, list):
|
||||
for edge in edges:
|
||||
if not isinstance(edge, Mapping):
|
||||
continue
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
if isinstance(source, str) and isinstance(target, str):
|
||||
incoming[target].append(source)
|
||||
return cls(node_ids=node_ids, incoming=incoming)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
return node_id in self._node_ids
|
||||
|
||||
def is_upstream(self, *, source_node_id: str, target_node_id: str) -> bool:
|
||||
if source_node_id == target_node_id:
|
||||
return False
|
||||
visited: set[str] = set()
|
||||
queue: deque[str] = deque(self._incoming.get(target_node_id, ()))
|
||||
while queue:
|
||||
candidate = queue.popleft()
|
||||
if candidate == source_node_id:
|
||||
return True
|
||||
if candidate in visited:
|
||||
continue
|
||||
visited.add(candidate)
|
||||
queue.extend(self._incoming.get(candidate, ()))
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _node_ids_from_graph(graph: Mapping[str, Any]) -> set[str]:
|
||||
node_ids: set[str] = set()
|
||||
nodes = graph.get("nodes")
|
||||
if not isinstance(nodes, list):
|
||||
return node_ids
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
continue
|
||||
node_id = node.get("id")
|
||||
if isinstance(node_id, str):
|
||||
node_ids.add(node_id)
|
||||
return node_ids
|
||||
664
api/dev/lint_response_contracts.py
Normal file
664
api/dev/lint_response_contracts.py
Normal file
@ -0,0 +1,664 @@
|
||||
"""Lint Flask-RESTX response docs against statically visible response serializers.
|
||||
|
||||
This checker intentionally stays conservative. It only reports a hard schema
|
||||
mismatch when both sides are statically known for the same 2xx status code:
|
||||
a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)``
|
||||
or ``Model.model_validate(...).model_dump()`` return.
|
||||
|
||||
Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing
|
||||
response schemas, and returns with non-literal status codes are classified as
|
||||
unknown so reviewers can triage them without blocking unrelated work. The one
|
||||
intentional non-schema mismatch is a known body/schema on a no-body status such
|
||||
as 204, 205, or 304.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"}
|
||||
NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value}
|
||||
DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web")
|
||||
|
||||
type Classification = Literal["valid", "mismatch", "unknown", "refactorable"]
|
||||
type ActualKind = Literal[
|
||||
"empty",
|
||||
"model",
|
||||
"model_dump_variable",
|
||||
"none",
|
||||
"raw_dict",
|
||||
"raw_list",
|
||||
"raw_value",
|
||||
"unknown",
|
||||
]
|
||||
type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef
|
||||
|
||||
HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus}
|
||||
HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DocumentedResponse:
    """One response declared in a route method's documentation."""

    # Documented HTTP status code.
    status: int
    # Documented model name, or None when no model is attached.
    model: str | None
    # Source line, reported as the "doc line" in mismatch reasons.
    line: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ActualResponse:
    """One response that a route method was statically observed to return."""

    # Status code of the return, or None.
    # NOTE(review): None presumably means the status was not statically known -- confirm.
    status: int | None
    # Coarse category of the returned value (one of the ``ActualKind`` literals).
    kind: ActualKind
    # Model name, or None.
    model: str | None
    # Source line, reported as the "return line" in mismatch reasons.
    line: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContractCheck:
    """Classification result for one route method, with the evidence behind it."""

    # One of the ``Classification`` literals.
    classification: Classification
    file: str
    class_name: str
    method: str
    route: str
    line: int
    # Human-readable explanation of the classification.
    reason: str
    # NOTE(review): presumably status code -> documented model name (or None) -- confirm.
    documented: dict[int, str | None]
    # Actual responses this result is based on.
    actual: list[ActualResponse]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ContractCheckContext:
    """Stable route-method context shared by every classification result."""

    file: str
    class_name: str
    method: str
    route: str
    line: int
    documented: dict[int, str | None]

    def build(
        self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse]
    ) -> ContractCheck:
        """Create a ``ContractCheck`` for this route method.

        The context fields are copied over unchanged; ``actual_responses`` is
        stored as a new list.
        """
        observed = [*actual_responses]
        return ContractCheck(
            classification=classification,
            file=self.file,
            class_name=self.class_name,
            method=self.method,
            route=self.route,
            line=self.line,
            reason=reason,
            documented=self.documented,
            actual=observed,
        )

    def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck:
        """Create a ``mismatch`` result whose reason cites the doc line and the return line."""
        located_reason = f"{reason} (doc line {documented.line}, return line {actual.line})"
        return self.build("mismatch", located_reason, [actual])
|
||||
|
||||
|
||||
@dataclass
class VariableAssignmentSummary:
    """Track whether a local name is safe to treat as one specific response model."""

    # Distinct model names the variable was assigned from.
    known_models: set[str] = field(default_factory=set)
    # Set once any assignment could not be tied to a model.
    has_unknown_assignment: bool = False

    def add_known(self, model: str) -> None:
        """Record an assignment from a known model."""
        self.known_models.add(model)

    def add_unknown(self) -> None:
        """Record an assignment whose model could not be determined."""
        self.has_unknown_assignment = True

    def single_known_model(self) -> str | None:
        """Return the model name when exactly one is known and nothing unknown was seen, else ``None``."""
        if self.has_unknown_assignment:
            return None
        if len(self.known_models) == 1:
            (only_model,) = self.known_models
            return only_model
        return None
|
||||
|
||||
|
||||
def dotted_name(node: ast.AST) -> str | None:
|
||||
match node:
|
||||
case ast.Name():
|
||||
return node.id
|
||||
case ast.Attribute():
|
||||
parent = dotted_name(node.value)
|
||||
if parent:
|
||||
return f"{parent}.{node.attr}"
|
||||
return node.attr
|
||||
return None
|
||||
|
||||
|
||||
def leaf_name(node: ast.AST) -> str | None:
    """Return the last dotted component of ``node``'s name, or ``None`` if it has no name."""
    full_name = dotted_name(node)
    # rpartition keeps the whole string when there is no dot at all.
    return None if full_name is None else full_name.rpartition(".")[2]
|
||||
|
||||
|
||||
def int_constant(node: ast.AST | None) -> int | None:
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, int):
|
||||
return node.value
|
||||
if isinstance(node, ast.Name):
|
||||
return HTTP_STATUS_NAMES.get(node.id)
|
||||
if isinstance(node, ast.Attribute):
|
||||
return HTTP_STATUS_NAMES.get(node.attr)
|
||||
return None
|
||||
|
||||
|
||||
def string_constant(node: ast.AST | None) -> str | None:
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
||||
return node.value
|
||||
return None
|
||||
|
||||
|
||||
def keyword_value(call: ast.Call, *names: str) -> ast.AST | None:
|
||||
for keyword in call.keywords:
|
||||
if keyword.arg in names:
|
||||
return keyword.value
|
||||
return None
|
||||
|
||||
|
||||
def is_probable_model_name(name: str) -> bool:
    """Heuristic: a non-empty identifier starting with an uppercase letter looks like a model class."""
    # Slicing keeps the empty string safe: "".isupper() is False.
    return name[:1].isupper()
|
||||
|
||||
|
||||
def model_name_from_schema_expr(node: ast.AST | None) -> str | None:
|
||||
if node is None:
|
||||
return None
|
||||
|
||||
if isinstance(node, ast.Subscript):
|
||||
value_name = dotted_name(node.value)
|
||||
if value_name and value_name.endswith(".models"):
|
||||
# register_response_schema_models stores schemas by model name; both
|
||||
# ns.models[Model.__name__] and ns.models["Model"] appear in controllers.
|
||||
key = node.slice
|
||||
if isinstance(key, ast.Attribute) and key.attr == "__name__":
|
||||
return leaf_name(key.value)
|
||||
return string_constant(key)
|
||||
|
||||
if isinstance(node, ast.Call):
|
||||
func_name = dotted_name(node.func)
|
||||
if func_name and func_name.endswith(".model"):
|
||||
return string_constant(node.args[0] if node.args else keyword_value(node, "name"))
|
||||
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None:
    """Turn a ``<ns>.response(status, description, schema)`` decorator into a ``DocumentedResponse``.

    Returns ``None`` when the decorator is not a call, is not a ``.response``
    call, or its status cannot be resolved statically. The status is the first
    positional argument or the ``code`` / ``status`` keyword. The schema is
    the ``model`` / ``schema`` keyword, falling back to the third positional
    argument.
    """
    if not isinstance(decorator, ast.Call):
        return None

    callee = dotted_name(decorator.func)
    if not (callee and callee.endswith(".response")):
        return None

    positional = decorator.args
    status = int_constant(positional[0] if positional else keyword_value(decorator, "code", "status"))
    if status is None:
        return None

    # A keyword schema takes precedence over the positional one.
    schema_expr = keyword_value(decorator, "model", "schema")
    if schema_expr is None and len(positional) >= 3:
        schema_expr = positional[2]

    return DocumentedResponse(
        status=status,
        model=model_name_from_schema_expr(schema_expr),
        line=decorator.lineno,
    )
|
||||
|
||||
|
||||
def route_from_decorator(decorator: ast.expr) -> str | None:
    """Return the literal route string of a ``<ns>.route(...)`` decorator, or ``None``.

    The route is the first positional argument, or the ``route`` / ``path``
    keyword when no positional argument is present.
    """
    if not isinstance(decorator, ast.Call):
        return None

    callee = dotted_name(decorator.func)
    if not (callee and callee.endswith(".route")):
        return None

    if decorator.args:
        route_expr = decorator.args[0]
    else:
        route_expr = keyword_value(decorator, "route", "path")
    return string_constant(route_expr)
|
||||
|
||||
|
||||
def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]:
    """Collect every non-empty literal route declared by the given decorators, in order."""
    routes: list[str] = []
    for decorator in decorators:
        route = route_from_decorator(decorator)
        if route:
            routes.append(route)
    return routes
|
||||
|
||||
|
||||
def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]:
    """Map each documented 2xx status to its ``DocumentedResponse``.

    Non-2xx responses are ignored. When a status is documented more than
    once, the later decorator in iteration order wins.
    """
    return {
        doc.status: doc
        for decorator in decorators
        if (doc := documented_response_from_decorator(decorator)) is not None and 200 <= doc.status < 300
    }
|
||||
|
||||
|
||||
def model_name_from_model_validate_call(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate":
|
||||
return leaf_name(node.func.value)
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_constructor_call(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call):
|
||||
return None
|
||||
if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id):
|
||||
return node.func.id
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_model_dump(node: ast.AST) -> str | None:
|
||||
if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump":
|
||||
return None
|
||||
|
||||
dumped_value = node.func.value
|
||||
if isinstance(dumped_value, ast.Call):
|
||||
return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def model_name_from_model_value(node: ast.AST) -> str | None:
    """Name the model produced by ``node``: a ``model_validate`` call first, a constructor call second."""
    validated = model_name_from_model_validate_call(node)
    if validated:
        return validated
    return model_name_from_constructor_call(node)
|
||||
|
||||
|
||||
def model_name_from_dump_response(node: ast.AST) -> str | None:
    """Return the model named in a ``dump_response(Model, ...)`` call.

    Both the bare ``dump_response`` and any dotted ``<x>.dump_response`` callee
    are accepted. The model is the first positional argument, or the
    ``model`` / ``schema`` / ``response_model`` keyword; it must be a bare
    name, otherwise ``None`` is returned.
    """
    if not isinstance(node, ast.Call):
        return None

    callee = dotted_name(node.func)
    # The last dotted component decides; a missing name never matches.
    if callee is None or callee.rpartition(".")[2] != "dump_response":
        return None

    if node.args:
        model_expr = node.args[0]
    else:
        model_expr = keyword_value(node, "model", "schema", "response_model")
    return model_expr.id if isinstance(model_expr, ast.Name) else None
|
||||
|
||||
|
||||
def actual_kind_from_expr(
|
||||
expr: ast.AST | None, variable_models: dict[str, str] | None = None
|
||||
) -> tuple[ActualKind, str | None]:
|
||||
if expr is None:
|
||||
return "none", None
|
||||
|
||||
dump_response_model = model_name_from_dump_response(expr)
|
||||
if dump_response_model:
|
||||
return "model", dump_response_model
|
||||
|
||||
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump":
|
||||
dumped_value = expr.func.value
|
||||
if isinstance(dumped_value, ast.Name) and variable_models:
|
||||
# A variable dump can match today, but it bypasses dump_response and
|
||||
# is easier to drift; keep it visible as refactorable.
|
||||
model_name = variable_models.get(dumped_value.id)
|
||||
if model_name:
|
||||
return "model_dump_variable", model_name
|
||||
|
||||
model_dump_model = model_name_from_model_dump(expr)
|
||||
if model_dump_model:
|
||||
return "model", model_dump_model
|
||||
|
||||
if isinstance(expr, ast.Constant):
|
||||
if expr.value is None:
|
||||
return "none", None
|
||||
if expr.value == "":
|
||||
return "empty", None
|
||||
return "raw_value", None
|
||||
|
||||
if isinstance(expr, ast.Dict):
|
||||
return "raw_dict", None
|
||||
|
||||
if isinstance(expr, ast.List):
|
||||
return "raw_list", None
|
||||
|
||||
return "unknown", None
|
||||
|
||||
|
||||
def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse:
    """Convert one ``return`` statement into an ``ActualResponse``.

    A non-empty tuple return is read as ``(body, status, ...)``. Without an
    explicit status element the status defaults to 200. With one, the status
    is whatever ``int_constant`` resolves, which may be ``None``.
    """
    returned = return_node.value
    status: int | None = 200

    if isinstance(returned, ast.Tuple) and returned.elts:
        body_expr, *trailing = returned.elts
        if trailing:
            # Dynamic statuses are not safe to coerce to 200; classify them as unknown.
            status = int_constant(trailing[0])
    else:
        body_expr = returned

    kind, model = actual_kind_from_expr(body_expr, variable_models)
    return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno)
|
||||
|
||||
|
||||
def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]:
    """Yield the nodes of a method body in pre-order without entering nested scopes.

    Nested function, lambda and class nodes are themselves yielded, but their
    children are not visited.
    """
    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)

    # Stack of child iterators: the innermost unfinished node's children sit on top.
    pending = [iter(method.body)]
    while pending:
        node = next(pending[-1], None)
        if node is None:
            pending.pop()
            continue

        yield node

        if not isinstance(node, nested_scopes):
            pending.append(ast.iter_child_nodes(node))
|
||||
|
||||
|
||||
def target_names(target: ast.AST) -> Iterable[str]:
|
||||
if isinstance(target, ast.Name):
|
||||
yield target.id
|
||||
elif isinstance(target, ast.Tuple | ast.List):
|
||||
for item in target.elts:
|
||||
yield from target_names(item)
|
||||
|
||||
|
||||
def record_assignment(
    assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None
) -> None:
    """Record one assignment for every target name.

    A ``model_name`` of ``None`` marks the targets as having received an
    unknown value; otherwise the model is added to each target's known set.
    """
    if model_name is None:
        # Once a name receives an unknown value, later model_dump() calls on it
        # are no longer a reliable signal for the returned schema.
        for name in targets:
            assignments[name].add_unknown()
        return

    for name in targets:
        assignments[name].add_known(model_name)
|
||||
|
||||
|
||||
def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]:
    """Map local names to the single response model they are unambiguously assigned.

    Plain, annotated and walrus assignments record the model recognised on
    their right-hand side, or an unknown value when none is recognised.
    Augmented assignments, loop targets, ``with ... as`` bindings and
    ``except ... as`` bindings always count as unknown. A name is reported
    only when exactly one model and no unknown value was recorded for it.
    """
    tracked: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary)

    for node in iter_method_nodes(method):
        if isinstance(node, ast.Assign):
            bound = [name for target in node.targets for name in target_names(target)]
            record_assignment(tracked, bound, model_name_from_model_value(node.value))
        elif isinstance(node, ast.AnnAssign):
            # A bare annotation (``x: Model``) assigns nothing.
            if node.value is not None:
                record_assignment(tracked, target_names(node.target), model_name_from_model_value(node.value))
        elif isinstance(node, (ast.AugAssign, ast.For, ast.AsyncFor)):
            # Mutation and loop targets overwrite prior values with runtime-dependent data.
            record_assignment(tracked, target_names(node.target), None)
        elif isinstance(node, (ast.With, ast.AsyncWith)):
            for item in node.items:
                if item.optional_vars is not None:
                    record_assignment(tracked, target_names(item.optional_vars), None)
        elif isinstance(node, ast.ExceptHandler):
            if node.name:
                tracked[node.name].add_unknown()
        elif isinstance(node, ast.NamedExpr):
            record_assignment(tracked, target_names(node.target), model_name_from_model_value(node.value))

    resolved: dict[str, str] = {}
    for name, summary in tracked.items():
        model = summary.single_known_model()
        if model is not None:
            resolved[name] = model
    return resolved
|
||||
|
||||
|
||||
def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]:
    """Collect the statically visible successful returns of one controller method.

    The analysis is intentionally shallow:

    * Only the method's own body is walked; nested functions, lambdas and
      classes are not entered.
    * Local variables assigned exactly one recognisable model are resolved
      first, so ``variable.model_dump()`` returns can be tied to a schema.
    * Each ``return`` becomes an ``ActualResponse``. A status that cannot be
      resolved is kept as ``None`` rather than assumed to be 200.
    * Returns with a resolved status outside 200-299 are dropped, because
      only successful response schemas are compared.
    """
    variable_models = variable_model_assignments_for_method(method)

    visible: list[ActualResponse] = []
    for node in iter_method_nodes(method):
        if not isinstance(node, ast.Return):
            continue
        response = actual_response_from_return(node, variable_models)
        if response.status is None or 200 <= response.status < 300:
            visible.append(response)
    return visible
|
||||
|
||||
|
||||
def display_path(file_path: Path, repo_root: Path) -> str:
    """Return ``file_path`` relative to ``repo_root`` when it lies inside it, else unchanged."""
    try:
        relative = file_path.relative_to(repo_root)
    except ValueError:
        # Not located under repo_root: fall back to the path as given.
        return str(file_path)
    return str(relative)
|
||||
|
||||
|
||||
def classify_method(
    *,
    actual_responses: Sequence[ActualResponse],
    class_name: str,
    documented_responses: dict[int, DocumentedResponse],
    file_path: Path,
    method: MethodNode,
    repo_root: Path,
    route: str,
) -> ContractCheck:
    """Compare a method's documented responses with its observed returns.

    The first definite contradiction found, in return order, produces a
    ``"mismatch"`` immediately. Otherwise the result is ``"unknown"`` if any
    return could not be judged, ``"refactorable"`` if every return matched but
    some go through ``variable.model_dump()``, and ``"valid"`` when all
    returns match their documented model.
    """
    context = ContractCheckContext(
        file=display_path(file_path, repo_root),
        class_name=class_name,
        method=method.name,
        route=route,
        line=method.lineno,
        documented={status: documented_responses[status].model for status in sorted(documented_responses)},
    )

    if not actual_responses:
        return context.build("unknown", "no statically visible 2xx return", [])

    # Body kinds that carry a payload and therefore contradict a no-body status.
    payload_kinds = frozenset({"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"})
    unknown_reasons: set[str] = set()
    refactorable_reasons: set[str] = set()

    for actual in actual_responses:
        status = actual.status
        if status is None:
            unknown_reasons.add(f"return line {actual.line} has non-literal or unsupported status")
            continue

        documented = documented_responses.get(status)

        if status in NO_BODY_STATUSES:
            # No-body statuses are contract violations even when the schema names
            # would otherwise match, because clients should not expect a payload.
            if documented is not None and documented.model is not None:
                return context.mismatch(
                    f"status {status} is a no-body response but documents {documented.model}",
                    documented,
                    actual,
                )
            if actual.kind in payload_kinds:
                # There is no documented model to cite, so point at the method line.
                placeholder_doc = DocumentedResponse(status=status, model=None, line=method.lineno)
                return context.mismatch(
                    f"status {status} is a no-body response but returns {actual.kind}",
                    placeholder_doc,
                    actual,
                )

        if actual.kind == "unknown":
            unknown_reasons.add(f"status {status} returns unknown body expression")
            continue

        if documented is None:
            unknown_reasons.add(f"status {status} has no @response doc")
            continue

        if documented.model is None:
            unknown_reasons.add(f"status {status} response doc has no schema model")
            continue

        if actual.kind not in ("model", "model_dump_variable") or actual.model is None:
            unknown_reasons.add(f"status {status} returns {actual.kind}")
            continue

        if documented.model != actual.model:
            return context.mismatch(
                f"status {status} documents {documented.model} but returns {actual.model}",
                documented,
                actual,
            )

        if actual.kind == "model_dump_variable":
            # The schema matches, but this path still deserves cleanup because
            # dump_response is the contract-aware serialization helper.
            refactorable_reasons.add(
                f"status {status} returns {actual.model}.model_dump() through a variable; prefer dump_response"
            )

    if unknown_reasons:
        # Unknown beats refactorable: if any return path is ambiguous, do not
        # imply the endpoint is merely a cleanup candidate.
        return context.build("unknown", "; ".join(sorted(unknown_reasons)), actual_responses)

    if refactorable_reasons:
        return context.build("refactorable", "; ".join(sorted(refactorable_reasons)), actual_responses)

    return context.build(
        "valid",
        "documented response schema matches statically visible return schema",
        actual_responses,
    )
|
||||
|
||||
|
||||
def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]:
    """Yield the Python files named by ``paths``.

    A ``.py`` file is yielded as given. A directory is searched recursively
    and its ``.py`` files are yielded in sorted order. Other files and
    non-existent paths are skipped.
    """
    for path in paths:
        if path.is_file():
            if path.suffix == ".py":
                yield path
        elif path.is_dir():
            yield from sorted(filter(Path.is_file, path.rglob("*.py")))
|
||||
|
||||
|
||||
def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]:
    """Parse one controller file and classify every routed HTTP handler in it.

    Only top-level classes are inspected. A handler is a method whose name is
    in ``HTTP_METHODS``. It uses its own route decorators when it has any and
    the class's otherwise; handlers with no route at all are skipped. One
    check is produced per route.
    """
    source = file_path.read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(file_path))
    results: list[ContractCheck] = []

    for class_node in tree.body:
        if not isinstance(class_node, ast.ClassDef):
            continue

        class_routes = routes_from_decorators(class_node.decorator_list)
        class_docs = response_docs_from_decorators(class_node.decorator_list)

        for member in class_node.body:
            is_function = isinstance(member, (ast.FunctionDef, ast.AsyncFunctionDef))
            if not is_function or member.name not in HTTP_METHODS:
                continue

            routes = routes_from_decorators(member.decorator_list) or class_routes
            if not routes:
                continue

            # Method-level @response decorators override class-level defaults for
            # the same status code, matching Flask-RESTX's common controller style.
            documented = dict(class_docs)
            documented.update(response_docs_from_decorators(member.decorator_list))

            actual = actual_responses_for_method(member)
            results.extend(
                classify_method(
                    actual_responses=actual,
                    class_name=class_node.name,
                    documented_responses=documented,
                    file_path=file_path,
                    method=member,
                    repo_root=repo_root,
                    route=route,
                )
                for route in routes
            )

    return results
|
||||
|
||||
|
||||
def as_jsonable(check: ContractCheck) -> dict[str, Any]:
    """Convert a check to a plain dict, with the ``documented`` status keys turned into strings."""
    stringified_docs = {str(status): model for status, model in check.documented.items()}
    return {**asdict(check), "documented": stringified_docs}
|
||||
|
||||
|
||||
def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None:
    """Write a human-readable report to standard output.

    A one-line summary with the count per classification comes first, then
    one section per non-empty classification in the order mismatch,
    refactorable, unknown, valid. The valid section is printed only when
    ``include_valid`` is true.
    """
    tally = Counter(check.classification for check in checks)
    sys.stdout.write(
        "Response contract lint: "
        f"{tally['valid']} valid, "
        f"{tally['mismatch']} mismatch, "
        f"{tally['refactorable']} refactorable, "
        f"{tally['unknown']} unknown\n"
    )

    sections = ["mismatch", "refactorable", "unknown"]
    if include_valid:
        sections.append("valid")

    for label in sections:
        members = [check for check in checks if check.classification == label]
        if not members:
            continue

        sys.stdout.write(f"\n{label.upper()}:\n")
        for check in members:
            location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}"
            sys.stdout.write(f"- {location}: {check.reason}\n")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Define the linter's command-line interface and parse the process arguments."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "paths",
        nargs="*",
        help="Files or directories to lint. Defaults to Flask controller directories.",
    )

    # Every option is a plain on/off switch; declare them in help-output order.
    switches = (
        ("--include-valid", "Print valid route methods in text output."),
        ("--json", "Emit machine-readable JSON."),
        (
            "--fail-on-mismatch",
            "Treat mismatched response contracts as failures. By default this linter is report-only.",
        ),
        (
            "--fail-on-unknown",
            "Treat unknown route methods as failures. By default this linter is report-only.",
        ),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)

    return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
    """Run the linter and return the process exit status.

    Relative paths are resolved against the directory two levels above this
    file; display paths are made relative to that directory's parent. The
    report is written as JSON grouped by classification when ``--json`` is
    given, otherwise as text. The result is 1 only when a ``--fail-on-*``
    flag is set and a check of the matching classification exists; otherwise 0.
    """
    args = parse_args()
    api_root = Path(__file__).resolve().parents[1]
    repo_root = api_root.parent

    requested = [Path(raw) for raw in (args.paths or list(DEFAULT_CONTROLLER_DIRS))]
    paths = [candidate if candidate.is_absolute() else api_root / candidate for candidate in requested]

    checks: list[ContractCheck] = sorted(
        (
            check
            for file_path in iter_controller_files(paths)
            for check in checks_for_file(file_path.resolve(), repo_root)
        ),
        key=lambda check: (check.classification, check.file, check.line, check.method),
    )

    if args.json:
        grouped: dict[str, list[dict[str, Any]]] = {}
        for check in checks:
            grouped.setdefault(check.classification, []).append(as_jsonable(check))
        sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n")
    else:
        print_text_report(checks, include_valid=bool(args.include_valid))

    seen = {check.classification for check in checks}
    failed = (bool(args.fail_on_mismatch) and "mismatch" in seen) or (
        bool(args.fail_on_unknown) and "unknown" in seen
    )
    return int(failed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@ -457,14 +457,16 @@ def init_app(app: DifyApp):
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here."
|
||||
join_timeout_ms = dify_config.PUBSUB_LISTENER_JOIN_TIMEOUT_MS
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client)
|
||||
return ShardedRedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams":
|
||||
return StreamsBroadcastChannel(
|
||||
_pubsub_redis_client,
|
||||
retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS,
|
||||
join_timeout_ms=join_timeout_ms,
|
||||
)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client)
|
||||
return RedisBroadcastChannel(_pubsub_redis_client, join_timeout_ms=join_timeout_ms)
|
||||
|
||||
|
||||
def redis_fallback[T](default_return: T | None = None): # type: ignore
|
||||
|
||||
@ -53,25 +53,27 @@ def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
|
||||
|
||||
def _convert_value(value: Any) -> Any:
|
||||
"""Recursively convert non-serializable values."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, (bool, int, float, str)):
|
||||
return value
|
||||
if isinstance(value, Segment):
|
||||
# Convert Segment to its underlying value
|
||||
return _convert_value(value.value)
|
||||
if isinstance(value, File):
|
||||
# Convert File to dict
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
# Convert Pydantic model to dict
|
||||
return _convert_value(value.model_dump(mode="json"))
|
||||
if isinstance(value, dict):
|
||||
return {k: _convert_value(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_convert_value(item) for item in value]
|
||||
# Fallback to string representation for unknown types
|
||||
return str(value)
|
||||
match value:
|
||||
case None:
|
||||
return None
|
||||
case bool() | int() | float() | str():
|
||||
return value
|
||||
case Segment():
|
||||
# Convert Segment to its underlying value
|
||||
return _convert_value(value.value)
|
||||
case File():
|
||||
# Convert File to dict
|
||||
return value.to_dict()
|
||||
case BaseModel():
|
||||
# Convert Pydantic model to dict
|
||||
return _convert_value(value.model_dump(mode="json"))
|
||||
case dict():
|
||||
return {k: _convert_value(v) for k, v in value.items()}
|
||||
case list() | tuple():
|
||||
return [_convert_value(item) for item in value]
|
||||
case _:
|
||||
# Fallback to string representation for unknown types
|
||||
return str(value)
|
||||
|
||||
try:
|
||||
converted = _convert_value(obj)
|
||||
@ -104,15 +106,15 @@ class DefaultNodeOTelParser:
|
||||
|
||||
span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
|
||||
|
||||
node_type = node.node_type
|
||||
if node_type == BuiltinNodeTypes.LLM:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
|
||||
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
|
||||
elif node_type == BuiltinNodeTypes.TOOL:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
|
||||
else:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
|
||||
match node.node_type:
|
||||
case BuiltinNodeTypes.LLM:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
|
||||
case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
|
||||
case BuiltinNodeTypes.TOOL:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
|
||||
case _:
|
||||
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
|
||||
|
||||
# Extract inputs and outputs from result_event
|
||||
if result_event and result_event.node_run_result:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user