diff --git a/.github/workflows/pyrefly-type-coverage-comment.yml b/.github/workflows/pyrefly-type-coverage-comment.yml index 52c16f31537..2fe9aa591de 100644 --- a/.github/workflows/pyrefly-type-coverage-comment.yml +++ b/.github/workflows/pyrefly-type-coverage-comment.yml @@ -63,8 +63,8 @@ jobs: id: render run: | comment_body="$(uv run --directory api python libs/pyrefly_type_coverage.py \ - --base base_report.json \ - < pr_report.json)" + --base "$GITHUB_WORKSPACE/base_report.json" \ + < "$GITHUB_WORKSPACE/pr_report.json")" { echo "### Pyrefly Type Coverage" diff --git a/.github/workflows/pyrefly-type-coverage.yml b/.github/workflows/pyrefly-type-coverage.yml index eae8debf1a5..915e406b571 100644 --- a/.github/workflows/pyrefly-type-coverage.yml +++ b/.github/workflows/pyrefly-type-coverage.yml @@ -65,6 +65,9 @@ jobs: # Save structured data for the fork-PR comment workflow cp /tmp/pyrefly_report_pr.json pr_report.json cp /tmp/pyrefly_report_base.json base_report.json + # Keep fork-PR comments correct while the trusted workflow_run job is + # still using the default-branch renderer, which resolves --base from api/. + cp /tmp/pyrefly_report_base.json api/base_report.json - name: Save PR number run: | @@ -77,6 +80,7 @@ jobs: path: | pr_report.json base_report.json + api/base_report.json pr_number.txt - name: Comment PR with type coverage diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 4ce121ba605..c40ca5c1eac 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -47,6 +47,10 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: uv run --directory api --dev lint-imports + - name: Run Response Contract Linter + if: steps.changed-files.outputs.any_changed == 'true' + run: uv run --project api --dev python api/dev/lint_response_contracts.py --fail-on-mismatch + - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' run: make type-check-core diff --git a/Makefile b/Makefile index 9d3ac4ee478..be665e71231 100644 --- a/Makefile +++ b/Makefile @@ -75,13 +75,19 @@ check: @echo "✅ Code check complete" lint: - @echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..." + @echo "🔧 Running ruff format, check with fixes, response contract lint, import linter, and dotenv-linter..." @uv run --project api --dev ruff format ./api @uv run --project api --dev ruff check --fix ./api + @$(MAKE) api-contract-lint @uv run --directory api --dev lint-imports @uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example @echo "✅ Linting complete" +api-contract-lint: + @echo "🔎 Linting Flask response contracts..." + @uv run --project api --dev python api/dev/lint_response_contracts.py + @echo "✅ Response contract lint complete" + type-check: @echo "📝 Running type checks (pyrefly + mypy)..." @./dev/pyrefly-check-local $(PATH_TO_CHECK) @@ -191,6 +197,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" + @echo " make api-contract-lint - Check Flask response docs against returned schemas" @echo " make type-check - Run type checks (pyrefly, mypy)" @echo " make type-check-core - Run core type checks (pyrefly, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @@ -204,4 +211,4 @@ help: @echo " make build-push-all - Build and push all Docker images" # Phony targets -.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint type-check test test-all +.PHONY: build-web build-api push-web push-api build-all push-all build-push-all dev-setup prepare-docker prepare-web prepare-api dev-clean help format check lint api-contract-lint type-check test test-all diff --git a/api/AGENTS.md b/api/AGENTS.md index 4abd14e7c08..984322590b3 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -195,6 +195,7 @@ Before opening a PR / submitting: - Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic. - Services: coordinate repositories, providers, background tasks; keep side effects explicit. - Document non-obvious behaviour with concise docstrings and comments. +- For `204 No Content` responses, return an empty body only; never return a dict, model, or other payload. - For Flask-RESTX controller request, query, and response schemas, follow `controllers/API_SCHEMA_GUIDE.md`. In short: use Pydantic models, document GET query params with `query_params_from_model(...)`, register response DTOs with `register_response_schema_models(...)`, serialize response DTOs with `dump_response(...)`, diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index 84903733b5e..56a07a8b4a9 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -36,6 +36,24 @@ class FileInfo(BaseModel): size: int +def decode_remote_url(url: str, query_string: bytes | str = b"") -> str: + decoded_url = urllib.parse.unquote(url) + if isinstance(query_string, bytes): + raw_query = query_string.decode() + else: + raw_query = query_string + if not raw_query: + return decoded_url + + if decoded_url.endswith(("?", "&")): + separator = "" + elif urllib.parse.urlsplit(decoded_url).query: + separator = "&" + else: + separator = "?" + return f"{decoded_url}{separator}{raw_query}" + + def guess_file_info_from_response(response: httpx.Response): url = str(response.url) # Try to extract filename from URL diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 6463b022b5a..aca22d5c5ae 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -146,7 +146,7 @@ class BaseApiKeyResource(Resource): db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//api-keys") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index cfeaec4af91..bf8b57685fc 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -269,12 +269,12 @@ class AnnotationApi(Resource): "message": "annotation_ids are required if the parameter is provided.", }, 400 - result = AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids) - return result, 204 + AppAnnotationService.delete_app_annotations_in_batch(str(app_id), annotation_ids) + return "", 204 # If no annotation_ids are provided, handle clearing all annotations else: AppAnnotationService.clear_all_annotations(str(app_id)) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//annotations/export") @@ -335,7 +335,7 @@ class AnnotationUpdateDeleteApi(Resource): @edit_permission_required def delete(self, app_id: UUID, annotation_id: UUID): AppAnnotationService.delete_app_annotation(str(app_id), str(annotation_id)) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//annotations/batch-import") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 39b31de4c1e..63c3462b094 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -678,7 +678,7 @@ class AppApi(Resource): app_service = AppService() app_service.delete_app(app_model) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//copy") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index c7347933cb1..19780183664 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -29,9 +29,6 @@ from fields.conversation_fields import ( from fields.conversation_fields import ( ConversationWithSummaryPagination as ConversationWithSummaryPaginationResponse, ) -from fields.conversation_fields import ( - ResultResponse, -) from libs.datetime_utils import naive_utc_now, parse_time_range from libs.login import current_account_with_tenant, login_required from models import Conversation, EndUser, Message, MessageAnnotation @@ -77,7 +74,6 @@ register_schema_models( ConversationMessageDetailResponse, ConversationWithSummaryPaginationResponse, ConversationDetailResponse, - ResultResponse, CompletionConversationQuery, ChatConversationQuery, ) @@ -194,7 +190,7 @@ class CompletionConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 @console_ns.route("/apps//chat-conversations") @@ -347,7 +343,7 @@ class ChatConversationDetailApi(Resource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 def _get_conversation(app_model, conversation_id): diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index 41acf39541a..a1123a580ea 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -128,6 +128,6 @@ class TraceAppConfigApi(Resource): result = OpsService.delete_tracing_app_config(app_id=app_model.id, tracing_provider=args.tracing_provider) if not result: raise TracingConfigNotExist() - return {"result": "success"}, 204 + return "", 204 except Exception as e: raise BadRequest(str(e)) diff --git a/api/controllers/console/app/workflow_comment.py b/api/controllers/console/app/workflow_comment.py index f011f576fde..27009953b77 100644 --- a/api/controllers/console/app/workflow_comment.py +++ b/api/controllers/console/app/workflow_comment.py @@ -311,7 +311,7 @@ class WorkflowCommentDetailApi(Resource): user_id=current_user.id, ) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//workflow/comments//resolve") @@ -431,7 +431,7 @@ class WorkflowCommentReplyDetailApi(Resource): user_id=current_user.id, ) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/apps//workflow/comments/mention-users") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 2e3b8d2295a..3f0650389f6 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -93,4 +93,4 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_tenant_id, binding_id) - return {"result": "success"}, 204 + return "", 204 diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index b70572a1985..1828796e439 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -560,7 +560,7 @@ class DatasetApi(Resource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {"result": "success"}, 204 + return "", 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: @@ -898,7 +898,7 @@ class DatasetApiDeleteApi(Resource): db.session.delete(key) db.session.commit() - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//api-keys/") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index d98da48daa1..fabd61e6b09 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -504,7 +504,7 @@ class DatasetDocumentListApi(Resource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets/init") @@ -966,7 +966,7 @@ class DocumentApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//documents//download") @@ -1204,7 +1204,7 @@ class DocumentPauseApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot pause completed document.") - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//documents//processing/resume") @@ -1236,7 +1236,7 @@ class DocumentRecoverApi(DocumentResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Document is not in paused status.") - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//retry") @@ -1279,7 +1279,7 @@ class DocumentRetryApi(DocumentResource): # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//documents//rename") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2eaa64994e8..1d3bc96c1bd 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -251,7 +251,7 @@ class DatasetDocumentSegmentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segments(segment_ids, document, dataset) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets//documents//segment/") @@ -467,7 +467,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {"result": "success"}, 204 + return "", 204 @console_ns.route( @@ -754,7 +754,7 @@ class ChildChunkUpdateApi(Resource): SegmentService.delete_child_chunk(child_chunk, dataset) except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return {"result": "success"}, 204 + return "", 204 @setup_required @login_required diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 37cee1c17af..d1cdc15d0bf 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -218,7 +218,7 @@ class ExternalApiTemplateApi(Resource): raise Forbidden() ExternalDatasetService.delete_external_knowledge_api(current_tenant_id, external_knowledge_api_id) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/datasets/external-knowledge-api//use-check") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index c6d041fc59e..4de5f32fb87 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -1,14 +1,18 @@ from typing import Literal -from flask_restx import Resource, marshal_with +from flask_restx import Resource from werkzeug.exceptions import NotFound from controllers.common.controller_schemas import MetadataUpdatePayload -from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required -from fields.dataset_fields import dataset_metadata_fields +from fields.dataset_fields import ( + DatasetMetadataBuiltInFieldsResponse, + DatasetMetadataListResponse, + DatasetMetadataResponse, +) +from libs.helper import dump_response from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( @@ -22,7 +26,12 @@ from services.metadata_service import MetadataService register_schema_models( console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail ) -register_response_schema_models(console_ns, SimpleResultResponse) +register_response_schema_models( + console_ns, + DatasetMetadataBuiltInFieldsResponse, + DatasetMetadataListResponse, + DatasetMetadataResponse, +) @console_ns.route("/datasets//metadata") @@ -31,7 +40,7 @@ class DatasetMetadataCreateApi(Resource): @login_required @account_initialization_required @enterprise_license_required - @marshal_with(dataset_metadata_fields) + @console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataArgs.__name__]) def post(self, dataset_id): current_user, _ = current_account_with_tenant() @@ -44,18 +53,22 @@ class DatasetMetadataCreateApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) - return metadata, 201 + return dump_response(DatasetMetadataResponse, metadata), 201 @setup_required @login_required @account_initialization_required @enterprise_license_required + @console_ns.response( + 200, "Metadata retrieved successfully", console_ns.models[DatasetMetadataListResponse.__name__] + ) def get(self, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - return MetadataService.get_dataset_metadatas(dataset), 200 + metadata = MetadataService.get_dataset_metadatas(dataset) + return dump_response(DatasetMetadataListResponse, metadata), 200 @console_ns.route("/datasets//metadata/") @@ -64,7 +77,7 @@ class DatasetMetadataApi(Resource): @login_required @account_initialization_required @enterprise_license_required - @marshal_with(dataset_metadata_fields) + @console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__]) def patch(self, dataset_id, metadata_id): current_user, _ = current_account_with_tenant() @@ -79,7 +92,7 @@ class DatasetMetadataApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name) - return metadata, 200 + return dump_response(DatasetMetadataResponse, metadata), 200 @setup_required @login_required @@ -96,7 +109,8 @@ class DatasetMetadataApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) MetadataService.delete_metadata(dataset_id_str, metadata_id_str) - return {"result": "success"}, 204 + # Frontend callers only await success and invalidate metadata caches; no response body is consumed. + return "", 204 @console_ns.route("/datasets/metadata/built-in") @@ -105,9 +119,14 @@ class DatasetMetadataBuiltInFieldApi(Resource): @login_required @account_initialization_required @enterprise_license_required + @console_ns.response( + 200, + "Built-in fields retrieved successfully", + console_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__], + ) def get(self): built_in_fields = MetadataService.get_built_in_fields() - return {"fields": built_in_fields}, 200 + return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200 @console_ns.route("/datasets//metadata/built-in/") @@ -116,7 +135,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @login_required @account_initialization_required @enterprise_license_required - @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response(204, "Action completed successfully") def post(self, dataset_id, action: Literal["enable", "disable"]): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) @@ -130,7 +149,8 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): MetadataService.enable_built_in_field(dataset) case "disable": MetadataService.disable_built_in_field(dataset) - return {"result": "success"}, 200 + # Frontend callers only await success and invalidate metadata caches; no response body is consumed. + return "", 204 @console_ns.route("/datasets//documents/metadata") @@ -140,7 +160,10 @@ class DocumentMetadataEditApi(Resource): @account_initialization_required @enterprise_license_required @console_ns.expect(console_ns.models[MetadataOperationData.__name__]) - @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) + @console_ns.response( + 204, + "Documents metadata updated successfully", + ) def post(self, dataset_id): current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) @@ -153,4 +176,5 @@ class DocumentMetadataEditApi(Resource): MetadataService.update_documents_metadata(dataset, metadata_args) - return {"result": "success"}, 200 + # Frontend callers only await success and invalidate caches; no response body is consumed. + return "", 204 diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 826610acda9..ae32571219c 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -105,7 +105,7 @@ class ConversationApi(InstalledAppResource): except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 @console_ns.route( diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index fb8d35013f6..4ad3dbc85fc 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -270,7 +270,7 @@ class InstalledAppApi(InstalledAppResource): db.session.delete(installed_app) db.session.commit() - return {"result": "success", "message": "App uninstalled successfully"}, 204 + return "", 204 @console_ns.response(200, "Success", console_ns.models[SimpleResultMessageResponse.__name__]) def patch(self, installed_app): diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 51e73049a42..09f214bd2b0 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -76,4 +76,4 @@ class SavedMessageApi(InstalledAppResource): SavedMessageService.delete(app_model, current_user, message_id) - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index e53bb95c24e..35e62e3c7e9 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -204,4 +204,4 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) - return {"result": "success"}, 204 + return "", 204 diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 0e9a66c5794..654991900dd 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -3,12 +3,12 @@ from werkzeug.exceptions import Unauthorized from controllers.common.schema import register_response_schema_models from libs.login import current_account_with_tenant, current_user, login_required -from services.feature_service import FeatureModel, FeatureService, SystemFeatureModel +from services.feature_service import FeatureModel, FeatureService, LimitationModel, SystemFeatureModel from . import console_ns from .wraps import account_initialization_required, cloud_utm_record, setup_required -register_response_schema_models(console_ns, FeatureModel, SystemFeatureModel) +register_response_schema_models(console_ns, FeatureModel, LimitationModel, SystemFeatureModel) @console_ns.route("/features") @@ -28,7 +28,32 @@ class FeatureApi(Resource): """Get feature configuration for current tenant""" _, current_tenant_id = current_account_with_tenant() - return FeatureService.get_features(current_tenant_id).model_dump() + payload = FeatureService.get_features( + current_tenant_id, + exclude_vector_space=True, + ).model_dump() + payload.pop("vector_space", None) + return payload + + +@console_ns.route("/features/vector-space") +class FeatureVectorSpaceApi(Resource): + @console_ns.doc("get_tenant_feature_vector_space") + @console_ns.doc(description="Get vector-space usage and limit for current tenant") + @console_ns.response( + 200, + "Success", + console_ns.models[LimitationModel.__name__], + ) + @setup_required + @login_required + @account_initialization_required + @cloud_utm_record + def get(self): + """Get vector-space usage and limit for current tenant""" + _, current_tenant_id = current_account_with_tenant() + + return FeatureService.get_vector_space(current_tenant_id).model_dump() @console_ns.route("/system-features") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fd3ed78986c..19f1fd8aabc 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,6 +1,5 @@ -import urllib.parse - import httpx +from flask import request from flask_restx import Resource from pydantic import BaseModel, Field @@ -34,7 +33,7 @@ class GetRemoteFileInfo(Resource): @console_ns.response(200, "Success", console_ns.models[RemoteFileInfo.__name__]) @login_required def get(self, url: str): - decoded_url = urllib.parse.unquote(url) + decoded_url = helpers.decode_remote_url(url, request.query_string) resp = ssrf_proxy.head(decoded_url) if resp.status_code != httpx.codes.OK: resp = ssrf_proxy.get(decoded_url, timeout=3) diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index f00a7e5c79b..221cb3e406d 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -194,7 +194,7 @@ class ModelProviderCredentialApi(Resource): tenant_id=current_tenant_id, provider=provider, credential_id=args.credential_id ) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/workspaces/current/model-providers//credentials/switch") diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index a24d6d0f7d5..c23207e402c 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -259,7 +259,7 @@ class ModelProviderModelApi(Resource): tenant_id=tenant_id, provider=provider, model=args.model, model_type=args.model_type ) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/workspaces/current/model-providers//models/credentials") @@ -395,7 +395,7 @@ class ModelProviderModelCredentialApi(Resource): credential_id=args.credential_id, ) - return {"result": "success"}, 204 + return "", 204 @console_ns.route("/workspaces/current/model-providers//models/credentials/switch") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index afab582bf2c..58bdd0f6114 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -1,15 +1,19 @@ from typing import Literal from flask_login import current_user -from flask_restx import marshal from werkzeug.exceptions import NotFound from controllers.common.controller_schemas import MetadataUpdatePayload -from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_model, register_schema_models from controllers.service_api import service_api_ns from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check -from fields.dataset_fields import dataset_metadata_fields +from fields.dataset_fields import ( + DatasetMetadataActionResponse, + DatasetMetadataBuiltInFieldsResponse, + DatasetMetadataListResponse, + DatasetMetadataResponse, +) +from libs.helper import dump_response from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( DocumentMetadataOperation, @@ -27,7 +31,13 @@ register_schema_models( DocumentMetadataOperation, MetadataOperationData, ) -register_response_schema_models(service_api_ns, SimpleResultResponse) +register_response_schema_models( + service_api_ns, + DatasetMetadataActionResponse, + DatasetMetadataBuiltInFieldsResponse, + DatasetMetadataListResponse, + DatasetMetadataResponse, +) @service_api_ns.route("/datasets//metadata") @@ -43,6 +53,9 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): 404: "Dataset not found", } ) + @service_api_ns.response( + 201, "Metadata created successfully", service_api_ns.models[DatasetMetadataResponse.__name__] + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" @@ -55,7 +68,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) - return marshal(metadata, dataset_metadata_fields), 201 + return dump_response(DatasetMetadataResponse, metadata), 201 @service_api_ns.doc("get_dataset_metadata") @service_api_ns.doc(description="Get all metadata for a dataset") @@ -67,13 +80,17 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): 404: "Dataset not found", } ) + @service_api_ns.response( + 200, "Metadata retrieved successfully", service_api_ns.models[DatasetMetadataListResponse.__name__] + ) def get(self, tenant_id, dataset_id): """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: raise NotFound("Dataset not found.") - return MetadataService.get_dataset_metadatas(dataset), 200 + metadata = MetadataService.get_dataset_metadatas(dataset) + return dump_response(DatasetMetadataListResponse, metadata), 200 @service_api_ns.route("/datasets//metadata/") @@ -89,6 +106,9 @@ class DatasetMetadataServiceApi(DatasetApiResource): 404: "Dataset or metadata not found", } ) + @service_api_ns.response( + 200, "Metadata updated successfully", service_api_ns.models[DatasetMetadataResponse.__name__] + ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): """Update metadata name.""" @@ -102,7 +122,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name) - return marshal(metadata, dataset_metadata_fields), 200 + return dump_response(DatasetMetadataResponse, metadata), 200 @service_api_ns.doc("delete_dataset_metadata") @service_api_ns.doc(description="Delete metadata") @@ -114,6 +134,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): 404: "Dataset or metadata not found", } ) + @service_api_ns.response(204, "Metadata deleted successfully") @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): """Delete metadata.""" @@ -138,10 +159,15 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): 401: "Unauthorized - invalid API token", } ) + @service_api_ns.response( + 200, + "Built-in fields retrieved successfully", + service_api_ns.models[DatasetMetadataBuiltInFieldsResponse.__name__], + ) def get(self, tenant_id, dataset_id): """Get all built-in metadata fields.""" built_in_fields = MetadataService.get_built_in_fields() - return {"fields": built_in_fields}, 200 + return dump_response(DatasetMetadataBuiltInFieldsResponse, {"fields": built_in_fields}), 200 @service_api_ns.route("/datasets//metadata/built-in/") @@ -157,9 +183,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): } ) @service_api_ns.response( - 200, - "Action completed successfully", - service_api_ns.models[SimpleResultResponse.__name__], + 200, "Action completed successfully", service_api_ns.models[DatasetMetadataActionResponse.__name__] ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): @@ -175,7 +199,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): MetadataService.enable_built_in_field(dataset) case "disable": MetadataService.disable_built_in_field(dataset) - return {"result": "success"}, 200 + return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 @service_api_ns.route("/datasets//documents/metadata") @@ -194,7 +218,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): @service_api_ns.response( 200, "Documents metadata updated successfully", - service_api_ns.models[SimpleResultResponse.__name__], + service_api_ns.models[DatasetMetadataActionResponse.__name__], ) @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): @@ -209,4 +233,4 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): MetadataService.update_documents_metadata(dataset, metadata_args) - return {"result": "success"}, 200 + return dump_response(DatasetMetadataActionResponse, {"result": "success"}), 200 diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 31151b5f9f3..41f8ef53a5b 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -11,7 +11,7 @@ register_response_schema_models(service_api_ns, IndexInfoResponse) @service_api_ns.route("/") class IndexApi(Resource): @service_api_ns.response(200, "Success", service_api_ns.models[IndexInfoResponse.__name__]) - def get(self): + def get(self) -> dict[str, str]: return { "welcome": "Dify OpenAPI", "api_version": "v1", diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 6afa00c727c..a99adb391f9 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -136,7 +136,7 @@ class ConversationApi(WebApiResource): ConversationService.delete(app_model, conversation_id, end_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 @web_ns.route("/conversations//name") diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ff583acd5c4..e9f727097b7 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,5 @@ -import urllib.parse - import httpx +from flask import request from pydantic import BaseModel, Field, HttpUrl import services @@ -59,7 +58,7 @@ class RemoteFileInfoApi(WebApiResource): Raises: HTTPException: If the remote file cannot be accessed """ - decoded_url = urllib.parse.unquote(url) + decoded_url = helpers.decode_remote_url(url, request.query_string) resp = ssrf_proxy.head(decoded_url) if resp.status_code != httpx.codes.OK: # failed back to get method diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b1382efbd82..e307367b647 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -112,4 +112,4 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) - return ResultResponse(result="success").model_dump(mode="json"), 204 + return "", 204 diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index a2186be100c..8d4a5931b82 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,4 +1,5 @@ import json +from typing import override from core.agent.cot_agent_runner import CotAgentRunner from graphon.file import file_manager @@ -66,6 +67,7 @@ class CotChatAgentRunner(CotAgentRunner): return prompt_messages + @override def _organize_prompt_messages(self) -> list[PromptMessage]: """ Organize diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 51a30998ae2..b677fc6af4d 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,4 +1,5 @@ import json +from typing import override from core.agent.cot_agent_runner import CotAgentRunner from graphon.model_runtime.entities.message_entities import ( @@ -51,6 +52,7 @@ class CotCompletionAgentRunner(CotAgentRunner): return historic_prompt + @override def _organize_prompt_messages(self) -> list[PromptMessage]: """ Organize prompt messages diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index a3cc7983526..a06595ac164 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Sequence -from typing import Any +from typing import Any, override from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter @@ -23,6 +23,7 @@ class PluginAgentStrategy(BaseAgentStrategy): self.declaration = declaration self.meta_version = meta_version + @override def get_parameters(self) -> Sequence[AgentStrategyParameter]: return self.declaration.parameters @@ -34,6 +35,7 @@ class PluginAgentStrategy(BaseAgentStrategy): params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name)) return params + @override def _invoke( self, params: dict[str, Any], diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 7cb0c9a8d3f..4f3c74deea2 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -20,6 +20,7 @@ class AdvancedChatAppGenerateResponseConverter( AppGenerateResponseConverter[ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse] ): @classmethod + @override def convert_blocking_full_response( cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse ) -> dict[str, Any]: @@ -59,6 +60,7 @@ class AdvancedChatAppGenerateResponseConverter( return response @classmethod + @override def convert_blocking_simple_response( cls, blocking_response: ChatbotAppBlockingResponse | AdvancedChatPausedBlockingResponse ) -> dict[str, Any]: @@ -76,6 +78,7 @@ class AdvancedChatAppGenerateResponseConverter( return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, Any, None]: @@ -107,6 +110,7 @@ class AdvancedChatAppGenerateResponseConverter( yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, Any, None]: diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index 03bc0a91085..618509101aa 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. @@ -37,6 +38,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. @@ -54,6 +56,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -85,6 +88,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter[Chatbot yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 26efcbfafd1..0869f0405b7 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking full response. @@ -37,6 +38,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): """ Convert blocking simple response. @@ -54,6 +56,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -85,6 +88,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter[ChatbotAppBl yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/completion/generate_response_converter.py b/api/core/app/apps/completion/generate_response_converter.py index ad978f58e0a..806575c2561 100644 --- a/api/core/app/apps/completion/generate_response_converter.py +++ b/api/core/app/apps/completion/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from pydantic import JsonValue @@ -16,6 +16,7 @@ from core.app.entities.task_entities import ( class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[CompletionAppBlockingResponse]): @classmethod + @override def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking full response. @@ -36,6 +37,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple return response @classmethod + @override def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): """ Convert blocking simple response. @@ -53,6 +55,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple return response @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -83,6 +86,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter[Comple yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 67fc016cba2..0b97809bf3a 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -21,6 +23,7 @@ class MessageBasedAppQueueManager(AppQueueManager): self._app_mode = app_mode self._message_id = str(message_id) + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue diff --git a/api/core/app/apps/pipeline/pipeline_queue_manager.py b/api/core/app/apps/pipeline/pipeline_queue_manager.py index 151b50f238e..c34b51c98c7 100644 --- a/api/core/app/apps/pipeline/pipeline_queue_manager.py +++ b/api/core/app/apps/pipeline/pipeline_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +21,7 @@ class PipelineQueueManager(AppQueueManager): self._app_mode = app_mode + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: """ Publish event to queue diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 9985e2d2752..fcdd1465d4f 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,3 +1,5 @@ +from typing import override + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom @@ -19,6 +21,7 @@ class WorkflowAppQueueManager(AppQueueManager): self._app_mode = app_mode + @override def _publish(self, event: AppQueueEvent, pub_from: PublishFrom): """ Publish event to queue diff --git a/api/core/app/apps/workflow/generate_response_converter.py b/api/core/app/apps/workflow/generate_response_converter.py index c390ad30c94..b286c53048d 100644 --- a/api/core/app/apps/workflow/generate_response_converter.py +++ b/api/core/app/apps/workflow/generate_response_converter.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any, cast +from typing import Any, cast, override from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter from core.app.entities.task_entities import ( @@ -18,6 +18,7 @@ class WorkflowAppGenerateResponseConverter( AppGenerateResponseConverter[WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse] ): @classmethod + @override def convert_blocking_full_response( cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse ) -> dict[str, Any]: @@ -29,6 +30,7 @@ class WorkflowAppGenerateResponseConverter( return dict(blocking_response.model_dump()) @classmethod + @override def convert_blocking_simple_response( cls, blocking_response: WorkflowAppBlockingResponse | WorkflowAppPausedBlockingResponse ) -> dict[str, Any]: @@ -40,6 +42,7 @@ class WorkflowAppGenerateResponseConverter( return cls.convert_blocking_full_response(blocking_response) @classmethod + @override def convert_stream_full_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: @@ -73,6 +76,7 @@ class WorkflowAppGenerateResponseConverter( yield response_chunk @classmethod + @override def convert_stream_simple_response( cls, stream_response: Generator[AppStreamResponse, None, None] ) -> Generator[dict[str, Any] | str, None, None]: diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py index a6c6e74f068..fd4f350e25d 100644 --- a/api/core/app/file_access/controller.py +++ b/api/core/app/file_access/controller.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Callable +from typing import override from sqlalchemy import and_, or_, select from sqlalchemy.orm import Session @@ -31,9 +32,11 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): ) -> None: self._scope_getter = scope_getter + @override def current_scope(self) -> FileAccessScope | None: return self._scope_getter() + @override def apply_upload_file_filters( self, stmt: Select[tuple[UploadFile]], @@ -62,6 +65,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): ) ) + @override def apply_tool_file_filters( self, stmt: Select[tuple[ToolFile]], @@ -78,6 +82,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + @override def get_upload_file( self, *, @@ -95,6 +100,7 @@ class DatabaseFileAccessController(FileAccessControllerProtocol): ) return session.scalar(stmt) + @override def get_tool_file( self, *, diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d5e6b04a4a6..9125236af61 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -8,6 +8,7 @@ scope updates that matter to chat applications. """ import logging +from typing import override from core.workflow.system_variables import SystemVariableKey, get_system_text from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID @@ -23,9 +24,11 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): super().__init__() self._conversation_variable_updater = conversation_variable_updater + @override def on_graph_start(self) -> None: pass + @override def on_event(self, event: GraphEngineEvent) -> None: if not isinstance(event, NodeRunVariableUpdatedEvent): return @@ -44,5 +47,6 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) + @override def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 9811f9f8308..d6517218992 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Annotated, Literal, Self +from typing import Annotated, Literal, Self, override from pydantic import BaseModel, Field from sqlalchemy import Engine @@ -83,6 +83,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): def _get_repo(self) -> APIWorkflowRunRepository: return DifyAPIRepositoryFactory.create_api_workflow_run_repository(self._session_maker) + @override def on_graph_start(self) -> None: """ Called when graph execution starts. @@ -92,6 +93,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): """ pass + @override def on_event(self, event: GraphEngineEvent) -> None: """ Called for every event emitted by the engine. @@ -132,6 +134,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer): pause_reasons=event.reasons, ) + @override def on_graph_end(self, error: Exception | None) -> None: """ Called when graph execution ends. diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 1a79a9f843e..3e28303a7d3 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,3 +1,5 @@ +from typing import override + from graphon.graph_engine.layers import GraphEngineLayer from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent @@ -9,9 +11,11 @@ class SuspendLayer(GraphEngineLayer): super().__init__() self._paused = False + @override def on_graph_start(self): self._paused = False + @override def on_event(self, event: GraphEngineEvent): """ Handle the paused event, stash runtime state into storage and wait for resume. @@ -19,6 +23,7 @@ class SuspendLayer(GraphEngineLayer): if isinstance(event, GraphRunPausedEvent): self._paused = True + @override def on_graph_end(self, error: Exception | None): """ """ self._paused = False diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index bb9fc1b6fa7..094c21944df 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -1,6 +1,6 @@ import logging import uuid -from typing import ClassVar +from typing import ClassVar, override from apscheduler.schedulers.background import BackgroundScheduler # type: ignore @@ -63,6 +63,7 @@ class TimeSliceLayer(GraphEngineLayer): except Exception: logger.exception("scheduler error during check if the workflow need to be suspended") + @override def on_graph_start(self): """ Start timer to check if the workflow need to be suspended. @@ -78,9 +79,11 @@ class TimeSliceLayer(GraphEngineLayer): id=self.schedule_id, ) + @override def on_event(self, event: GraphEngineEvent): pass + @override def on_graph_end(self, error: Exception | None) -> None: self.stopped = True # remove the scheduler diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index b60fe82ffe7..65b8af67065 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -1,6 +1,6 @@ import logging from datetime import UTC, datetime -from typing import Any, ClassVar +from typing import Any, ClassVar, override from pydantic import TypeAdapter @@ -37,9 +37,11 @@ class TriggerPostLayer(GraphEngineLayer): self.start_time = start_time self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity + @override def on_graph_start(self): pass + @override def on_event(self, event: GraphEngineEvent): """ Update trigger log with success or failure. @@ -82,5 +84,6 @@ class TriggerPostLayer(GraphEngineLayer): repo.update(trigger_log) session.commit() + @override def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 587f7002866..90fdf410227 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -7,7 +7,7 @@ import os import time import urllib.parse from collections.abc import Generator -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, override from configs import dify_config from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol @@ -40,15 +40,19 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): self._file_access_controller = file_access_controller @property + @override def multimodal_send_format(self) -> str: return dify_config.MULTIMODAL_SEND_FORMAT + @override def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: return graphon_ssrf_proxy.get(url, follow_redirects=follow_redirects) + @override def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) + @override def load_file_bytes(self, *, file: File) -> bytes: storage_key = self._resolve_storage_key(file=file) data = storage.load(storage_key, stream=False) @@ -56,6 +60,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): raise ValueError(f"file {storage_key} is not a bytes object") return data + @override def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: if file.transfer_method == FileTransferMethod.REMOTE_URL: return file.remote_url @@ -86,6 +91,7 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): ) return None + @override def resolve_upload_file_url( self, *, @@ -101,10 +107,12 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): query["as_attachment"] = "true" return f"{url}?{urllib.parse.urlencode(query)}" + @override def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + @override def verify_preview_signature( self, *, diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index c5dba65232d..619590c81ee 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -12,7 +12,7 @@ state. from collections.abc import Mapping from dataclasses import dataclass from datetime import datetime -from typing import Any, Union +from typing import Any, Union, override from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.helper.trace_id_helper import ParentTraceContext @@ -98,12 +98,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): # ------------------------------------------------------------------ # GraphEngineLayer lifecycle # ------------------------------------------------------------------ + @override def on_graph_start(self) -> None: self._workflow_execution = None self._node_execution_cache.clear() self._node_snapshots.clear() self._node_sequence = 0 + @override def on_event(self, event: GraphEngineEvent) -> None: match event: case GraphRunStartedEvent(): @@ -131,6 +133,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): case NodeRunPauseRequestedEvent(): self._handle_node_pause_requested(event) + @override def on_graph_end(self, error: Exception | None) -> None: return diff --git a/api/core/datasource/local_file/local_file_plugin.py b/api/core/datasource/local_file/local_file_plugin.py index 070a89cb2fe..7d20d62cbbc 100644 --- a/api/core/datasource/local_file/local_file_plugin.py +++ b/api/core/datasource/local_file/local_file_plugin.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( @@ -22,8 +24,10 @@ class LocalFileDatasourcePlugin(DatasourcePlugin): self.tenant_id = tenant_id self.plugin_unique_identifier = plugin_unique_identifier + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.LOCAL_FILE + @override def get_icon_url(self, tenant_id: str) -> str: return self.icon diff --git a/api/core/datasource/local_file/local_file_provider.py b/api/core/datasource/local_file/local_file_provider.py index b2b6f51dd38..6b6f78b33d7 100644 --- a/api/core/datasource/local_file/local_file_provider.py +++ b/api/core/datasource/local_file/local_file_provider.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -19,12 +19,14 @@ class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderContro self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider """ return DatasourceProviderType.LOCAL_FILE + @override def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider diff --git a/api/core/datasource/online_document/online_document_plugin.py b/api/core/datasource/online_document/online_document_plugin.py index ce23da1e09c..2fbf575d553 100644 --- a/api/core/datasource/online_document/online_document_plugin.py +++ b/api/core/datasource/online_document/online_document_plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -67,5 +67,6 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/core/datasource/online_document/online_document_provider.py b/api/core/datasource/online_document/online_document_provider.py index a128b479f4d..f1f34c8ba19 100644 --- a/api/core/datasource/online_document/online_document_provider.py +++ b/api/core/datasource/online_document/online_document_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -17,6 +19,7 @@ class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderC self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/datasource/online_drive/online_drive_plugin.py b/api/core/datasource/online_drive/online_drive_plugin.py index 64715226ccd..a12226a1603 100644 --- a/api/core/datasource/online_drive/online_drive_plugin.py +++ b/api/core/datasource/online_drive/online_drive_plugin.py @@ -1,4 +1,5 @@ from collections.abc import Generator +from typing import override from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -67,5 +68,6 @@ class OnlineDriveDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.ONLINE_DRIVE diff --git a/api/core/datasource/online_drive/online_drive_provider.py b/api/core/datasource/online_drive/online_drive_provider.py index d0923ed807c..d4a6942d098 100644 --- a/api/core/datasource/online_drive/online_drive_provider.py +++ b/api/core/datasource/online_drive/online_drive_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -17,6 +19,7 @@ class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderCont self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/datasource/website_crawl/website_crawl_plugin.py b/api/core/datasource/website_crawl/website_crawl_plugin.py index 087ac65a7ad..c5c9b4c0f2e 100644 --- a/api/core/datasource/website_crawl/website_crawl_plugin.py +++ b/api/core/datasource/website_crawl/website_crawl_plugin.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping -from typing import Any +from typing import Any, override from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_runtime import DatasourceRuntime @@ -47,5 +47,6 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin): provider_type=provider_type, ) + @override def datasource_provider_type(self) -> str: return DatasourceProviderType.WEBSITE_CRAWL diff --git a/api/core/datasource/website_crawl/website_crawl_provider.py b/api/core/datasource/website_crawl/website_crawl_provider.py index 8c0f20ce2de..0dfdf3c0ddd 100644 --- a/api/core/datasource/website_crawl/website_crawl_provider.py +++ b/api/core/datasource/website_crawl/website_crawl_provider.py @@ -1,3 +1,5 @@ +from typing import override + from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType @@ -21,6 +23,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon self.plugin_unique_identifier = plugin_unique_identifier @property + @override def provider_type(self) -> DatasourceProviderType: """ returns the type of the provider diff --git a/api/core/helper/code_executor/javascript/javascript_code_provider.py b/api/core/helper/code_executor/javascript/javascript_code_provider.py index ae324b83a95..b9442fdd1eb 100644 --- a/api/core/helper/code_executor/javascript/javascript_code_provider.py +++ b/api/core/helper/code_executor/javascript/javascript_code_provider.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import override from core.helper.code_executor.code_executor import CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider @@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider class JavascriptCodeProvider(CodeNodeProvider): @staticmethod + @override def get_language() -> str: return CodeLanguage.JAVASCRIPT @classmethod + @override def get_default_code(cls) -> str: return dedent( """ diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py index e28f027a3aa..249e67666c6 100644 --- a/api/core/helper/code_executor/javascript/javascript_transformer.py +++ b/api/core/helper/code_executor/javascript/javascript_transformer.py @@ -1,10 +1,12 @@ from textwrap import dedent +from typing import override from core.helper.code_executor.template_transformer import TemplateTransformer class NodeJsTemplateTransformer(TemplateTransformer): @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" {cls._code_placeholder} diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py index 5e4807401ee..9cf5089f7b5 100644 --- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py +++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from textwrap import dedent -from typing import Any +from typing import Any, override from core.helper.code_executor.template_transformer import TemplateTransformer @@ -10,6 +10,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): _template_b64_placeholder: str = "{{template_b64}}" @classmethod + @override def transform_response(cls, response: str): """ Transform response to dict @@ -19,6 +20,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return {"result": cls.extract_result_str_from_response(response)} @classmethod + @override def assemble_runner_script(cls, code: str, inputs: Mapping[str, Any]) -> str: """ Override base class to use base64 encoding for template code. @@ -34,6 +36,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return script @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" import jinja2 @@ -61,6 +64,7 @@ class Jinja2TemplateTransformer(TemplateTransformer): return runner_script @classmethod + @override def get_preload_script(cls) -> str: preload_script = dedent(""" import jinja2 diff --git a/api/core/helper/code_executor/python3/python3_code_provider.py b/api/core/helper/code_executor/python3/python3_code_provider.py index 151bf0e201a..8157a477a1e 100644 --- a/api/core/helper/code_executor/python3/python3_code_provider.py +++ b/api/core/helper/code_executor/python3/python3_code_provider.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import override from core.helper.code_executor.code_executor import CodeLanguage from core.helper.code_executor.code_node_provider import CodeNodeProvider @@ -6,10 +7,12 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider class Python3CodeProvider(CodeNodeProvider): @staticmethod + @override def get_language() -> str: return CodeLanguage.PYTHON3 @classmethod + @override def get_default_code(cls) -> str: return dedent( """ diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py index ee866eeb816..07947ee7929 100644 --- a/api/core/helper/code_executor/python3/python3_transformer.py +++ b/api/core/helper/code_executor/python3/python3_transformer.py @@ -1,10 +1,12 @@ from textwrap import dedent +from typing import override from core.helper.code_executor.template_transformer import TemplateTransformer class Python3TemplateTransformer(TemplateTransformer): @classmethod + @override def get_runner_script(cls) -> str: runner_script = dedent(f""" {cls._code_placeholder} diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 9f167ca49c1..6ad08dfe178 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -1,7 +1,7 @@ import json from abc import ABC, abstractmethod from json import JSONDecodeError -from typing import Any +from typing import Any, override from extensions.ext_redis import redis_client @@ -47,6 +47,7 @@ class SingletonProviderCredentialsCache(ProviderCredentialsCache): provider_identity=provider_identity, ) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider_type = kwargs["provider_type"] @@ -61,6 +62,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): def __init__(self, tenant_id: str, provider: str, credential_id: str): super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + @override def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] provider = kwargs["provider"] diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 91e92712b7c..b2493934bfb 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -43,13 +43,16 @@ request_error = httpx.RequestError max_retries_exceeded_error = MaxRetriesExceededError -def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]: +def _create_proxy_mounts(verify: bool) -> dict[str, httpx.HTTPTransport]: + """Build per-scheme proxy transports with the same TLS policy as the SSRF client.""" return { "http://": httpx.HTTPTransport( proxy=dify_config.SSRF_PROXY_HTTP_URL, + verify=verify, ), "https://": httpx.HTTPTransport( proxy=dify_config.SSRF_PROXY_HTTPS_URL, + verify=verify, ), } @@ -64,7 +67,7 @@ def _build_ssrf_client(verify: bool) -> httpx.Client: if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: return httpx.Client( - mounts=_create_proxy_mounts(), + mounts=_create_proxy_mounts(verify=verify), verify=verify, limits=_SSRF_CLIENT_LIMITS, ) diff --git a/api/core/logging/filters.py b/api/core/logging/filters.py index dee1432363c..e0a3148d56a 100644 --- a/api/core/logging/filters.py +++ b/api/core/logging/filters.py @@ -2,6 +2,7 @@ import contextlib import logging +from typing import override import flask @@ -15,6 +16,7 @@ class TraceContextFilter(logging.Filter): Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id. """ + @override def filter(self, record: logging.LogRecord) -> bool: # Get trace context from OpenTelemetry trace_id, span_id = self._get_otel_context() @@ -54,6 +56,7 @@ class IdentityContextFilter(logging.Filter): Extracts tenant_id, user_id, and user_type from Flask-Login current_user. """ + @override def filter(self, record: logging.LogRecord) -> bool: identity = self._extract_identity() record.tenant_id = identity.get("tenant_id", "") diff --git a/api/core/logging/structured_formatter.py b/api/core/logging/structured_formatter.py index ae7be91c173..56ea7482427 100644 --- a/api/core/logging/structured_formatter.py +++ b/api/core/logging/structured_formatter.py @@ -3,7 +3,7 @@ import logging import traceback from datetime import UTC, datetime -from typing import Any, NotRequired, TypedDict +from typing import Any, NotRequired, TypedDict, override import orjson @@ -58,6 +58,7 @@ class StructuredJSONFormatter(logging.Formatter): super().__init__() self._service_name = service_name or dify_config.APPLICATION_NAME + @override def format(self, record: logging.LogRecord) -> str: log_dict = self._build_log_dict(record) try: diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 28165592fc5..ec9a1906f8a 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from pydantic import BaseModel, Field from sqlalchemy import select @@ -25,6 +25,7 @@ class ApiModeration(Moderation): name: str = "api" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -43,6 +44,7 @@ class ApiModeration(Moderation): if not extension: raise ValueError("API-based Extension not found. Please check it again.") + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -59,6 +61,7 @@ class ApiModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 7d80d3a53c8..339574556d9 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from typing import Any +from typing import Any, override from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult @@ -8,6 +8,7 @@ class KeywordsModeration(Moderation): name: str = "keywords" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -28,6 +29,7 @@ class KeywordsModeration(Moderation): if len(keywords_row_len) > 100: raise ValueError("the number of rows for the keywords must be less than 100") + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -49,6 +51,7 @@ class KeywordsModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 6e6e94502cc..4b7a08eb277 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, override from core.model_manager import ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult @@ -9,6 +9,7 @@ class OpenAIModeration(Moderation): name: str = "openai_moderation" @classmethod + @override def validate_config(cls, tenant_id: str, config: dict[str, Any]): """ Validate the incoming form config data. @@ -19,6 +20,7 @@ class OpenAIModeration(Moderation): """ cls._validate_inputs_and_outputs_config(config, True) + @override def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult: flagged = False preset_response = "" @@ -36,6 +38,7 @@ class OpenAIModeration(Moderation): flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response ) + @override def moderation_for_outputs(self, text: str) -> ModerationOutputsResult: flagged = False preset_response = "" diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py index 62573ba2f53..d555f4d9657 100644 --- a/api/core/plugin/impl/model_runtime.py +++ b/api/core/plugin/impl/model_runtime.py @@ -4,7 +4,7 @@ import hashlib import logging from collections.abc import Generator, Iterable, Sequence from threading import Lock -from typing import IO, Any, Literal, cast, overload +from typing import IO, Any, Literal, cast, overload, override from pydantic import ValidationError from redis import RedisError @@ -118,6 +118,7 @@ class PluginModelRuntime(ModelRuntime): self._provider_entities = None self._provider_entities_lock = Lock() + @override def fetch_model_providers(self) -> Sequence[ProviderEntity]: if self._provider_entities is not None: return self._provider_entities @@ -130,6 +131,7 @@ class PluginModelRuntime(ModelRuntime): return self._provider_entities + @override def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: provider_schema = self._get_provider_schema(provider) @@ -172,6 +174,7 @@ class PluginModelRuntime(ModelRuntime): mime_type = image_mime_types.get(extension, "image/png") return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + @override def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: plugin_id, provider_name = self._split_provider(provider) self.client.validate_provider_credentials( @@ -182,6 +185,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def validate_model_credentials( self, *, @@ -201,6 +205,7 @@ class PluginModelRuntime(ModelRuntime): credentials=credentials, ) + @override def get_model_schema( self, *, @@ -294,6 +299,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @override def invoke_llm( self, *, @@ -357,6 +363,7 @@ class PluginModelRuntime(ModelRuntime): stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @override def invoke_llm_with_structured_output( self, *, @@ -396,6 +403,7 @@ class PluginModelRuntime(ModelRuntime): stream=stream, ) + @override def get_llm_num_tokens( self, *, @@ -422,6 +430,7 @@ class PluginModelRuntime(ModelRuntime): tools=list(tools) if tools else None, ) + @override def invoke_text_embedding( self, *, @@ -443,6 +452,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def invoke_multimodal_embedding( self, *, @@ -464,6 +474,7 @@ class PluginModelRuntime(ModelRuntime): input_type=input_type, ) + @override def get_text_embedding_num_tokens( self, *, @@ -483,6 +494,7 @@ class PluginModelRuntime(ModelRuntime): texts=texts, ) + @override def invoke_rerank( self, *, @@ -508,6 +520,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_multimodal_rerank( self, *, @@ -533,6 +546,7 @@ class PluginModelRuntime(ModelRuntime): top_n=top_n, ) + @override def invoke_tts( self, *, @@ -554,6 +568,7 @@ class PluginModelRuntime(ModelRuntime): voice=voice, ) + @override def get_tts_model_voices( self, *, @@ -573,6 +588,7 @@ class PluginModelRuntime(ModelRuntime): language=language, ) + @override def invoke_speech_to_text( self, *, @@ -592,6 +608,7 @@ class PluginModelRuntime(ModelRuntime): file=file, ) + @override def invoke_moderation( self, *, diff --git a/api/dev/lint_response_contracts.py b/api/dev/lint_response_contracts.py new file mode 100644 index 00000000000..6cdb3e289c0 --- /dev/null +++ b/api/dev/lint_response_contracts.py @@ -0,0 +1,664 @@ +"""Lint Flask-RESTX response docs against statically visible response serializers. + +This checker intentionally stays conservative. It only reports a hard schema +mismatch when both sides are statically known for the same 2xx status code: +a documented ``@ns.response(..., Model)`` and an actual ``dump_response(Model, ...)`` +or ``Model.model_validate(...).model_dump()`` return. + +Raw dictionaries, raw lists, ``None`` responses, streaming helpers, missing +response schemas, and returns with non-literal status codes are classified as +unknown so reviewers can triage them without blocking unrelated work. The one +intentional non-schema mismatch is a known body/schema on a no-body status such +as 204, 205, or 304. +""" + +from __future__ import annotations + +import argparse +import ast +import json +import sys +from collections import Counter, defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import asdict, dataclass, field +from http import HTTPStatus +from pathlib import Path +from typing import Any, Literal + +HTTP_METHODS = {"delete", "get", "head", "options", "patch", "post", "put"} +NO_BODY_STATUSES = {HTTPStatus.NO_CONTENT.value, HTTPStatus.RESET_CONTENT.value, HTTPStatus.NOT_MODIFIED.value} +DEFAULT_CONTROLLER_DIRS = ("controllers/console", "controllers/service_api", "controllers/web") + +type Classification = Literal["valid", "mismatch", "unknown", "refactorable"] +type ActualKind = Literal[ + "empty", + "model", + "model_dump_variable", + "none", + "raw_dict", + "raw_list", + "raw_value", + "unknown", +] +type MethodNode = ast.FunctionDef | ast.AsyncFunctionDef + +HTTP_STATUS_NAMES = {status.name: status.value for status in HTTPStatus} +HTTP_STATUS_NAMES.update({f"HTTP_{status.value}_{status.name}": status.value for status in HTTPStatus}) + + +@dataclass(frozen=True) +class DocumentedResponse: + status: int + model: str | None + line: int + + +@dataclass(frozen=True) +class ActualResponse: + status: int | None + kind: ActualKind + model: str | None + line: int + + +@dataclass(frozen=True) +class ContractCheck: + classification: Classification + file: str + class_name: str + method: str + route: str + line: int + reason: str + documented: dict[int, str | None] + actual: list[ActualResponse] + + +@dataclass(frozen=True) +class ContractCheckContext: + """Stable route-method context shared by every classification result.""" + + file: str + class_name: str + method: str + route: str + line: int + documented: dict[int, str | None] + + def build( + self, classification: Classification, reason: str, actual_responses: Sequence[ActualResponse] + ) -> ContractCheck: + return ContractCheck( + classification=classification, + file=self.file, + class_name=self.class_name, + method=self.method, + route=self.route, + line=self.line, + reason=reason, + documented=self.documented, + actual=list(actual_responses), + ) + + def mismatch(self, reason: str, documented: DocumentedResponse, actual: ActualResponse) -> ContractCheck: + return self.build("mismatch", f"{reason} (doc line {documented.line}, return line {actual.line})", [actual]) + + +@dataclass +class VariableAssignmentSummary: + """Track whether a local name is safe to treat as one specific response model.""" + + known_models: set[str] = field(default_factory=set) + has_unknown_assignment: bool = False + + def add_known(self, model: str) -> None: + self.known_models.add(model) + + def add_unknown(self) -> None: + self.has_unknown_assignment = True + + def single_known_model(self) -> str | None: + if self.has_unknown_assignment or len(self.known_models) != 1: + return None + return next(iter(self.known_models)) + + +def dotted_name(node: ast.AST) -> str | None: + match node: + case ast.Name(): + return node.id + case ast.Attribute(): + parent = dotted_name(node.value) + if parent: + return f"{parent}.{node.attr}" + return node.attr + return None + + +def leaf_name(node: ast.AST) -> str | None: + name = dotted_name(node) + if name is None: + return None + return name.rsplit(".", 1)[-1] + + +def int_constant(node: ast.AST | None) -> int | None: + if isinstance(node, ast.Constant) and isinstance(node.value, int): + return node.value + if isinstance(node, ast.Name): + return HTTP_STATUS_NAMES.get(node.id) + if isinstance(node, ast.Attribute): + return HTTP_STATUS_NAMES.get(node.attr) + return None + + +def string_constant(node: ast.AST | None) -> str | None: + if isinstance(node, ast.Constant) and isinstance(node.value, str): + return node.value + return None + + +def keyword_value(call: ast.Call, *names: str) -> ast.AST | None: + for keyword in call.keywords: + if keyword.arg in names: + return keyword.value + return None + + +def is_probable_model_name(name: str) -> bool: + return bool(name) and name[0].isupper() + + +def model_name_from_schema_expr(node: ast.AST | None) -> str | None: + if node is None: + return None + + if isinstance(node, ast.Subscript): + value_name = dotted_name(node.value) + if value_name and value_name.endswith(".models"): + # register_response_schema_models stores schemas by model name; both + # ns.models[Model.__name__] and ns.models["Model"] appear in controllers. + key = node.slice + if isinstance(key, ast.Attribute) and key.attr == "__name__": + return leaf_name(key.value) + return string_constant(key) + + if isinstance(node, ast.Call): + func_name = dotted_name(node.func) + if func_name and func_name.endswith(".model"): + return string_constant(node.args[0] if node.args else keyword_value(node, "name")) + + if isinstance(node, ast.Name): + return node.id + + return None + + +def documented_response_from_decorator(decorator: ast.expr) -> DocumentedResponse | None: + if not isinstance(decorator, ast.Call): + return None + + func_name = dotted_name(decorator.func) + if not func_name or not func_name.endswith(".response"): + return None + + status_expr = decorator.args[0] if decorator.args else keyword_value(decorator, "code", "status") + status = int_constant(status_expr) + if status is None: + return None + + schema_expr: ast.AST | None = decorator.args[2] if len(decorator.args) >= 3 else None + schema_expr = keyword_value(decorator, "model", "schema") or schema_expr + + return DocumentedResponse( + status=status, + model=model_name_from_schema_expr(schema_expr), + line=decorator.lineno, + ) + + +def route_from_decorator(decorator: ast.expr) -> str | None: + if not isinstance(decorator, ast.Call): + return None + + func_name = dotted_name(decorator.func) + if not func_name or not func_name.endswith(".route"): + return None + + return string_constant(decorator.args[0] if decorator.args else keyword_value(decorator, "route", "path")) + + +def routes_from_decorators(decorators: Iterable[ast.expr]) -> list[str]: + return [route for decorator in decorators if (route := route_from_decorator(decorator))] + + +def response_docs_from_decorators(decorators: Iterable[ast.expr]) -> dict[int, DocumentedResponse]: + docs: dict[int, DocumentedResponse] = {} + for decorator in decorators: + response = documented_response_from_decorator(decorator) + if response and 200 <= response.status < 300: + docs[response.status] = response + return docs + + +def model_name_from_model_validate_call(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + if isinstance(node.func, ast.Attribute) and node.func.attr == "model_validate": + return leaf_name(node.func.value) + return None + + +def model_name_from_constructor_call(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + if isinstance(node.func, ast.Name) and is_probable_model_name(node.func.id): + return node.func.id + return None + + +def model_name_from_model_dump(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute) or node.func.attr != "model_dump": + return None + + dumped_value = node.func.value + if isinstance(dumped_value, ast.Call): + return model_name_from_model_validate_call(dumped_value) or model_name_from_constructor_call(dumped_value) + + return None + + +def model_name_from_model_value(node: ast.AST) -> str | None: + return model_name_from_model_validate_call(node) or model_name_from_constructor_call(node) + + +def model_name_from_dump_response(node: ast.AST) -> str | None: + if not isinstance(node, ast.Call): + return None + + func_name = dotted_name(node.func) + if func_name != "dump_response" and not (func_name and func_name.endswith(".dump_response")): + return None + + model_expr = node.args[0] if node.args else keyword_value(node, "model", "schema", "response_model") + if isinstance(model_expr, ast.Name): + return model_expr.id + return None + + +def actual_kind_from_expr( + expr: ast.AST | None, variable_models: dict[str, str] | None = None +) -> tuple[ActualKind, str | None]: + if expr is None: + return "none", None + + dump_response_model = model_name_from_dump_response(expr) + if dump_response_model: + return "model", dump_response_model + + if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute) and expr.func.attr == "model_dump": + dumped_value = expr.func.value + if isinstance(dumped_value, ast.Name) and variable_models: + # A variable dump can match today, but it bypasses dump_response and + # is easier to drift; keep it visible as refactorable. + model_name = variable_models.get(dumped_value.id) + if model_name: + return "model_dump_variable", model_name + + model_dump_model = model_name_from_model_dump(expr) + if model_dump_model: + return "model", model_dump_model + + if isinstance(expr, ast.Constant): + if expr.value is None: + return "none", None + if expr.value == "": + return "empty", None + return "raw_value", None + + if isinstance(expr, ast.Dict): + return "raw_dict", None + + if isinstance(expr, ast.List): + return "raw_list", None + + return "unknown", None + + +def actual_response_from_return(return_node: ast.Return, variable_models: dict[str, str]) -> ActualResponse: + status: int | None = 200 + body_expr = return_node.value + + if isinstance(return_node.value, ast.Tuple) and return_node.value.elts: + body_expr = return_node.value.elts[0] + if len(return_node.value.elts) >= 2: + # Dynamic statuses are not safe to coerce to 200; classify them as unknown. + status = int_constant(return_node.value.elts[1]) + + kind, model = actual_kind_from_expr(body_expr, variable_models) + return ActualResponse(status=status, kind=kind, model=model, line=return_node.lineno) + + +def iter_method_nodes(method: MethodNode) -> Iterable[ast.AST]: + """Yield method body nodes while ignoring nested function/class scopes.""" + + stack: list[ast.AST] = list(reversed(method.body)) + while stack: + node = stack.pop() + yield node + + if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.Lambda | ast.ClassDef): + continue + + stack.extend(reversed(list(ast.iter_child_nodes(node)))) + + +def target_names(target: ast.AST) -> Iterable[str]: + if isinstance(target, ast.Name): + yield target.id + elif isinstance(target, ast.Tuple | ast.List): + for item in target.elts: + yield from target_names(item) + + +def record_assignment( + assignments: defaultdict[str, VariableAssignmentSummary], targets: Iterable[str], model_name: str | None +) -> None: + for target in targets: + if model_name is None: + # Once a name receives an unknown value, later model_dump() calls on it + # are no longer a reliable signal for the returned schema. + assignments[target].add_unknown() + else: + assignments[target].add_known(model_name) + + +def variable_model_assignments_for_method(method: MethodNode) -> dict[str, str]: + """Infer local variables that are unambiguously assigned one response model.""" + + assignments: defaultdict[str, VariableAssignmentSummary] = defaultdict(VariableAssignmentSummary) + + for node in iter_method_nodes(method): + match node: + case ast.Assign(targets=targets, value=value): + record_assignment( + assignments, + (name for target in targets for name in target_names(target)), + model_name_from_model_value(value), + ) + case ast.AnnAssign(target=target, value=value) if value is not None: + record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + case ast.AugAssign(target=target) | ast.For(target=target) | ast.AsyncFor(target=target): + # Mutation and loop targets overwrite prior values with runtime-dependent data. + record_assignment(assignments, target_names(target), None) + case ast.With(items=items) | ast.AsyncWith(items=items): + for item in items: + if item.optional_vars is not None: + record_assignment(assignments, target_names(item.optional_vars), None) + case ast.ExceptHandler(name=name) if name: + assignments[name].add_unknown() + case ast.NamedExpr(target=target, value=value): + record_assignment(assignments, target_names(target), model_name_from_model_value(value)) + + return {name: model for name, summary in assignments.items() if (model := summary.single_known_model()) is not None} + + +def actual_responses_for_method(method: MethodNode) -> list[ActualResponse]: + """Extract statically visible 2xx returns from one controller method. + + The analysis is deliberately shallow and conservative: + + 1. Walk only the method's own body, skipping nested functions/classes. + 2. Infer local variables that are assigned exactly one recognizable response + model, so ``response.model_dump()`` can still be connected to its schema. + 3. Treat any later unknown assignment, mutation target, loop target, context + manager binding, or exception binding as invalidating that variable. + 4. For each top-level return path, split Flask-style ``(body, status)`` + tuples, classify the body expression, and keep non-literal statuses as + ``None`` so the classifier reports them as unknown instead of assuming 200. + 5. Drop non-2xx literal statuses, since response contracts here only compare + successful response schemas. + """ + + variable_models = variable_model_assignments_for_method(method) + responses: list[ActualResponse] = [] + for node in iter_method_nodes(method): + if isinstance(node, ast.Return): + responses.append(actual_response_from_return(node, variable_models)) + return [response for response in responses if response.status is None or 200 <= response.status < 300] + + +def display_path(file_path: Path, repo_root: Path) -> str: + try: + return str(file_path.relative_to(repo_root)) + except ValueError: + return str(file_path) + + +def classify_method( + *, + actual_responses: Sequence[ActualResponse], + class_name: str, + documented_responses: dict[int, DocumentedResponse], + file_path: Path, + method: MethodNode, + repo_root: Path, + route: str, +) -> ContractCheck: + documented_summary = {status: response.model for status, response in sorted(documented_responses.items())} + context = ContractCheckContext( + file=display_path(file_path, repo_root), + class_name=class_name, + method=method.name, + route=route, + line=method.lineno, + documented=documented_summary, + ) + + if not actual_responses: + return context.build("unknown", "no statically visible 2xx return", []) + + unknown_reasons: list[str] = [] + refactorable_reasons: list[str] = [] + + for actual in actual_responses: + if actual.status is None: + unknown_reasons.append(f"return line {actual.line} has non-literal or unsupported status") + continue + + documented = documented_responses.get(actual.status) + + if actual.status in NO_BODY_STATUSES: + # No-body statuses are contract violations even when the schema names + # would otherwise match, because clients should not expect a payload. + if documented is not None and documented.model is not None: + return context.mismatch( + f"status {actual.status} is a no-body response but documents {documented.model}", + documented, + actual, + ) + if actual.kind in {"model", "model_dump_variable", "raw_dict", "raw_list", "raw_value"}: + no_body_doc = DocumentedResponse(status=actual.status, model=None, line=method.lineno) + return context.mismatch( + f"status {actual.status} is a no-body response but returns {actual.kind}", + no_body_doc, + actual, + ) + if actual.kind == "unknown": + unknown_reasons.append(f"status {actual.status} returns unknown body expression") + continue + + if documented is None: + unknown_reasons.append(f"status {actual.status} has no @response doc") + continue + + if documented.model is None: + unknown_reasons.append(f"status {actual.status} response doc has no schema model") + continue + + if actual.kind == "model_dump_variable" and actual.model is not None: + if documented.model != actual.model: + return context.mismatch( + f"status {actual.status} documents {documented.model} but returns {actual.model}", + documented, + actual, + ) + # The schema matches, but this path still deserves cleanup because + # dump_response is the contract-aware serialization helper. + refactorable_reasons.append( + f"status {actual.status} returns {actual.model}.model_dump() through a variable; prefer dump_response" + ) + continue + + if actual.kind != "model" or actual.model is None: + unknown_reasons.append(f"status {actual.status} returns {actual.kind}") + continue + + if documented.model != actual.model: + return context.mismatch( + f"status {actual.status} documents {documented.model} but returns {actual.model}", + documented, + actual, + ) + + if unknown_reasons: + # Unknown beats refactorable: if any return path is ambiguous, do not + # imply the endpoint is merely a cleanup candidate. + return context.build("unknown", "; ".join(sorted(set(unknown_reasons))), actual_responses) + + if refactorable_reasons: + return context.build("refactorable", "; ".join(sorted(set(refactorable_reasons))), actual_responses) + + return context.build( + "valid", + "documented response schema matches statically visible return schema", + actual_responses, + ) + + +def iter_controller_files(paths: Iterable[Path]) -> Iterable[Path]: + for path in paths: + if path.is_file() and path.suffix == ".py": + yield path + elif path.is_dir(): + yield from sorted(child for child in path.rglob("*.py") if child.is_file()) + + +def checks_for_file(file_path: Path, repo_root: Path) -> list[ContractCheck]: + module = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path)) + checks: list[ContractCheck] = [] + + for node in module.body: + if not isinstance(node, ast.ClassDef): + continue + + class_routes = routes_from_decorators(node.decorator_list) + class_documented = response_docs_from_decorators(node.decorator_list) + + for item in node.body: + if not isinstance(item, ast.FunctionDef | ast.AsyncFunctionDef) or item.name not in HTTP_METHODS: + continue + + routes = routes_from_decorators(item.decorator_list) or class_routes + if not routes: + continue + + documented = {**class_documented, **response_docs_from_decorators(item.decorator_list)} + # Method-level @response decorators override class-level defaults for + # the same status code, matching Flask-RESTX's common controller style. + actual = actual_responses_for_method(item) + for route in routes: + checks.append( + classify_method( + actual_responses=actual, + class_name=node.name, + documented_responses=documented, + file_path=file_path, + method=item, + repo_root=repo_root, + route=route, + ) + ) + + return checks + + +def as_jsonable(check: ContractCheck) -> dict[str, Any]: + data = asdict(check) + data["documented"] = {str(status): model for status, model in check.documented.items()} + return data + + +def print_text_report(checks: Sequence[ContractCheck], *, include_valid: bool) -> None: + counts = Counter(check.classification for check in checks) + sys.stdout.write( + "Response contract lint: " + f"{counts['valid']} valid, " + f"{counts['mismatch']} mismatch, " + f"{counts['refactorable']} refactorable, " + f"{counts['unknown']} unknown\n" + ) + + for classification in ("mismatch", "refactorable", "unknown", "valid"): + filtered = [check for check in checks if check.classification == classification] + if classification == "valid" and not include_valid: + continue + if not filtered: + continue + + sys.stdout.write(f"\n{classification.upper()}:\n") + for check in filtered: + location = f"{check.file}:{check.line} {check.class_name}.{check.method.upper()} {check.route}" + sys.stdout.write(f"- {location}: {check.reason}\n") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "paths", + nargs="*", + help="Files or directories to lint. Defaults to Flask controller directories.", + ) + parser.add_argument("--include-valid", action="store_true", help="Print valid route methods in text output.") + parser.add_argument("--json", action="store_true", help="Emit machine-readable JSON.") + parser.add_argument( + "--fail-on-mismatch", + action="store_true", + help="Treat mismatched response contracts as failures. By default this linter is report-only.", + ) + parser.add_argument( + "--fail-on-unknown", + action="store_true", + help="Treat unknown route methods as failures. By default this linter is report-only.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + api_root = Path(__file__).resolve().parents[1] + repo_root = api_root.parent + raw_paths = args.paths or list(DEFAULT_CONTROLLER_DIRS) + paths = [path if path.is_absolute() else api_root / path for path in map(Path, raw_paths)] + + checks: list[ContractCheck] = [] + for file_path in iter_controller_files(paths): + checks.extend(checks_for_file(file_path.resolve(), repo_root)) + + checks.sort(key=lambda check: (check.classification, check.file, check.line, check.method)) + + if args.json: + grouped = defaultdict(list) + for check in checks: + grouped[check.classification].append(as_jsonable(check)) + sys.stdout.write(f"{json.dumps(grouped, indent=2, sort_keys=True)}\n") + else: + print_text_report(checks, include_valid=bool(args.include_valid)) + + has_mismatch = any(check.classification == "mismatch" for check in checks) + has_unknown = any(check.classification == "unknown" for check in checks) + return int((bool(args.fail_on_mismatch) and has_mismatch) or (bool(args.fail_on_unknown) and has_unknown)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py deleted file mode 100644 index a2dda1dc15a..00000000000 --- a/api/fields/api_based_extension_fields.py +++ /dev/null @@ -1,23 +0,0 @@ -from flask_restx import fields - -from libs.helper import TimestampField - - -class HiddenAPIKey(fields.Raw): - def output(self, key, obj, **kwargs): - api_key = obj.api_key - # If the length of the api_key is less than 8 characters, show the first and last characters - if len(api_key) <= 8: - return api_key[0] + "******" + api_key[-1] - # If the api_key is greater than 8 characters, show the first three and the last three characters - else: - return api_key[:3] + "******" + api_key[-3:] - - -api_based_extension_fields = { - "id": fields.String, - "name": fields.String, - "api_endpoint": fields.String, - "api_key": HiddenAPIKey, - "created_at": TimestampField, -} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 55c66b4f1db..8e0d5a39af5 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -1,5 +1,6 @@ from flask_restx import fields +from fields.base import ResponseModel from libs.helper import TimestampField dataset_fields = { @@ -14,6 +15,38 @@ dataset_fields = { "permission_keys": fields.List(fields.String), } + +class DatasetMetadataResponse(ResponseModel): + id: str + type: str + name: str + + +class DatasetMetadataListItemResponse(ResponseModel): + id: str + name: str + type: str + count: int = 0 + + +class DatasetMetadataListResponse(ResponseModel): + doc_metadata: list[DatasetMetadataListItemResponse] + built_in_field_enabled: bool + + +class DatasetMetadataBuiltInFieldResponse(ResponseModel): + name: str + type: str + + +class DatasetMetadataBuiltInFieldsResponse(ResponseModel): + fields: list[DatasetMetadataBuiltInFieldResponse] + + +class DatasetMetadataActionResponse(ResponseModel): + result: str + + reranking_model_fields = {"reranking_provider_name": fields.String, "reranking_model_name": fields.String} keyword_setting_fields = {"keyword_weight": fields.Float} @@ -135,9 +168,3 @@ dataset_query_detail_fields = { "created_by": fields.String, "created_at": TimestampField, } - -dataset_metadata_fields = { - "id": fields.String, - "type": fields.String, - "name": fields.String, -} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py deleted file mode 100644 index 16dd26a10e4..00000000000 --- a/api/fields/installed_app_fields.py +++ /dev/null @@ -1,26 +0,0 @@ -from flask_restx import fields - -from libs.helper import AppIconUrlField, TimestampField - -app_fields = { - "id": fields.String, - "name": fields.String, - "mode": fields.String, - "icon_type": fields.String, - "icon": fields.String, - "icon_background": fields.String, - "icon_url": AppIconUrlField, - "use_icon_as_answer_icon": fields.Boolean, -} - -installed_app_fields = { - "id": fields.String, - "app": fields.Nested(app_fields), - "app_owner_tenant_id": fields.String, - "is_pinned": fields.Boolean, - "last_used_at": TimestampField, - "editable": fields.Boolean, - "uninstallable": fields.Boolean, -} - -installed_app_list_fields = {"installed_apps": fields.List(fields.Nested(installed_app_fields))} diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py deleted file mode 100644 index a70f051807b..00000000000 --- a/api/fields/workflow_app_log_fields.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -from datetime import datetime -from typing import Any - -from flask_restx import Namespace, fields -from pydantic import field_validator - -from fields.base import ResponseModel -from fields.end_user_fields import SimpleEndUser, simple_end_user_fields -from fields.member_fields import SimpleAccount, simple_account_fields -from fields.workflow_run_fields import ( - WorkflowRunForArchivedLogResponse, - WorkflowRunForLogResponse, - build_workflow_run_for_archived_log_model, - build_workflow_run_for_log_model, - workflow_run_for_archived_log_fields, - workflow_run_for_log_fields, -) -from libs.helper import TimestampField, to_timestamp - -workflow_app_log_partial_fields = { - "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_log_fields, attribute="workflow_run", allow_null=True), - "details": fields.Raw(attribute="details"), - "created_from": fields.String, - "created_by_role": fields.String, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "created_at": TimestampField, -} - - -def build_workflow_app_log_partial_model(api_or_ns: Namespace): - """Build the workflow app log partial model for the API or Namespace.""" - workflow_run_model = build_workflow_run_for_log_model(api_or_ns) - - copied_fields = workflow_app_log_partial_fields.copy() - copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) - return api_or_ns.model("WorkflowAppLogPartial", copied_fields) - - -workflow_archived_log_partial_fields = { - "id": fields.String, - "workflow_run": fields.Nested(workflow_run_for_archived_log_fields, allow_null=True), - "trigger_metadata": fields.Raw, - "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), - "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), - "created_at": TimestampField, -} - - -def build_workflow_archived_log_partial_model(api_or_ns: Namespace): - """Build the workflow archived log partial model for the API or Namespace.""" - workflow_run_model = build_workflow_run_for_archived_log_model(api_or_ns) - - copied_fields = workflow_archived_log_partial_fields.copy() - copied_fields["workflow_run"] = fields.Nested(workflow_run_model, allow_null=True) - return api_or_ns.model("WorkflowArchivedLogPartial", copied_fields) - - -workflow_app_log_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer, - "total": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(workflow_app_log_partial_fields)), -} - - -def build_workflow_app_log_pagination_model(api_or_ns: Namespace): - """Build the workflow app log pagination model for the API or Namespace.""" - # Build the nested partial model first - workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) - - copied_fields = workflow_app_log_pagination_fields.copy() - copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) - return api_or_ns.model("WorkflowAppLogPagination", copied_fields) - - -workflow_archived_log_pagination_fields = { - "page": fields.Integer, - "limit": fields.Integer, - "total": fields.Integer, - "has_more": fields.Boolean, - "data": fields.List(fields.Nested(workflow_archived_log_partial_fields)), -} - - -def build_workflow_archived_log_pagination_model(api_or_ns: Namespace): - """Build the workflow archived log pagination model for the API or Namespace.""" - workflow_archived_log_partial_model = build_workflow_archived_log_partial_model(api_or_ns) - - copied_fields = workflow_archived_log_pagination_fields.copy() - copied_fields["data"] = fields.List(fields.Nested(workflow_archived_log_partial_model)) - return api_or_ns.model("WorkflowArchivedLogPagination", copied_fields) - - -class WorkflowAppLogPartialResponse(ResponseModel): - id: str - workflow_run: WorkflowRunForLogResponse | None = None - details: Any = None - created_from: str | None = None - created_by_role: str | None = None - created_by_account: SimpleAccount | None = None - created_by_end_user: SimpleEndUser | None = None - created_at: int | None = None - - @field_validator("created_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return to_timestamp(value) - - -class WorkflowArchivedLogPartialResponse(ResponseModel): - id: str - workflow_run: WorkflowRunForArchivedLogResponse | None = None - trigger_metadata: Any = None - created_by_account: SimpleAccount | None = None - created_by_end_user: SimpleEndUser | None = None - created_at: int | None = None - - @field_validator("created_at", mode="before") - @classmethod - def _normalize_timestamp(cls, value: datetime | int | None) -> int | None: - return to_timestamp(value) - - -class WorkflowAppLogPaginationResponse(ResponseModel): - page: int - limit: int - total: int - has_more: bool - data: list[WorkflowAppLogPartialResponse] - - -class WorkflowArchivedLogPaginationResponse(ResponseModel): - page: int - limit: int - total: int - has_more: bool - data: list[WorkflowArchivedLogPartialResponse] diff --git a/api/fields/workflow_trigger_fields.py b/api/fields/workflow_trigger_fields.py deleted file mode 100644 index ce51d1833a3..00000000000 --- a/api/fields/workflow_trigger_fields.py +++ /dev/null @@ -1,25 +0,0 @@ -from flask_restx import fields - -trigger_fields = { - "id": fields.String, - "trigger_type": fields.String, - "title": fields.String, - "node_id": fields.String, - "provider_name": fields.String, - "icon": fields.String, - "status": fields.String, - "created_at": fields.DateTime(dt_format="iso8601"), - "updated_at": fields.DateTime(dt_format="iso8601"), -} - -triggers_list_fields = {"data": fields.List(fields.Nested(trigger_fields))} - - -webhook_trigger_fields = { - "id": fields.String, - "webhook_id": fields.String, - "webhook_url": fields.String, - "webhook_debug_url": fields.String, - "node_id": fields.String, - "created_at": fields.DateTime(dt_format="iso8601"), -} diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index cc8c1e1d6b2..8536cc93ae2 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -4413,9 +4413,9 @@ Initialize dataset with documents #### GET ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Built-in fields retrieved successfully | [DatasetMetadataBuiltInFieldsResponse](#datasetmetadatabuiltinfieldsresponse) | ### /datasets/notion-indexing-estimate @@ -4730,9 +4730,9 @@ then asynchronously generates summary indexes for the provided documents. ##### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 200 | Success | [SimpleResultResponse](#simpleresultresponse) | +| Code | Description | +| ---- | ----------- | +| 204 | Documents metadata updated successfully | ### /datasets/{dataset_id}/documents/status/{action}/batch @@ -5342,9 +5342,9 @@ Get dataset indexing status ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Metadata retrieved successfully | [DatasetMetadataListResponse](#datasetmetadatalistresponse) | #### POST ##### Parameters @@ -5356,9 +5356,9 @@ Get dataset indexing status ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 201 | Metadata created successfully | [DatasetMetadataResponse](#datasetmetadataresponse) | ### /datasets/{dataset_id}/metadata/built-in/{action} @@ -5372,9 +5372,9 @@ Get dataset indexing status ##### Responses -| Code | Description | Schema | -| ---- | ----------- | ------ | -| 200 | Success | [SimpleResultResponse](#simpleresultresponse) | +| Code | Description | +| ---- | ----------- | +| 204 | Action completed successfully | ### /datasets/{dataset_id}/metadata/{metadata_id} @@ -5403,9 +5403,9 @@ Get dataset indexing status ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Success | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Metadata updated successfully | [DatasetMetadataResponse](#datasetmetadataresponse) | ### /datasets/{dataset_id}/notion/sync @@ -5693,6 +5693,23 @@ Get feature configuration for current tenant | ---- | ----------- | ------ | | 200 | Success | [FeatureModel](#featuremodel) | +### /features/vector-space + +#### GET +##### Summary + +Get vector-space usage and limit for current tenant + +##### Description + +Get vector-space usage and limit for current tenant + +##### Responses + +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Success | [LimitationModel](#limitationmodel) | + ### /files/support-type #### GET @@ -11733,6 +11750,43 @@ Condition detail | ---- | ---- | ----------- | -------- | | keyword_weight | number | | No | +#### DatasetMetadataBuiltInFieldResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| name | string | | Yes | +| type | string | | Yes | + +#### DatasetMetadataBuiltInFieldsResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| fields | [ [DatasetMetadataBuiltInFieldResponse](#datasetmetadatabuiltinfieldresponse) ] | | Yes | + +#### DatasetMetadataListItemResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| count | integer | | No | +| id | string | | Yes | +| name | string | | Yes | +| type | string | | Yes | + +#### DatasetMetadataListResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| built_in_field_enabled | boolean | | Yes | +| doc_metadata | [ [DatasetMetadataListItemResponse](#datasetmetadatalistitemresponse) ] | | Yes | + +#### DatasetMetadataResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| id | string | | Yes | +| name | string | | Yes | +| type | string | | Yes | + #### DatasetPermissionEnum | Name | Type | Description | Required | diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index 87dbe8c1ba5..17058e5e621 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -877,7 +877,7 @@ Update metadata for multiple documents | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Documents metadata updated successfully | [SimpleResultResponse](#simpleresultresponse) | +| 200 | Documents metadata updated successfully | [DatasetMetadataActionResponse](#datasetmetadataactionresponse) | | 401 | Unauthorized - invalid API token | | | 404 | Dataset not found | | @@ -1378,11 +1378,11 @@ Get all metadata for a dataset ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Metadata retrieved successfully | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Metadata retrieved successfully | [DatasetMetadataListResponse](#datasetmetadatalistresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | #### POST ##### Summary @@ -1402,11 +1402,11 @@ Create metadata for a dataset ##### Responses -| Code | Description | -| ---- | ----------- | -| 201 | Metadata created successfully | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 201 | Metadata created successfully | [DatasetMetadataResponse](#datasetmetadataresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset not found | | ### /datasets/{dataset_id}/metadata/built-in @@ -1427,10 +1427,10 @@ Get all built-in metadata fields ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Built-in fields retrieved successfully | -| 401 | Unauthorized - invalid API token | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Built-in fields retrieved successfully | [DatasetMetadataBuiltInFieldsResponse](#datasetmetadatabuiltinfieldsresponse) | +| 401 | Unauthorized - invalid API token | | ### /datasets/{dataset_id}/metadata/built-in/{action} @@ -1454,7 +1454,7 @@ Enable or disable built-in metadata field | Code | Description | Schema | | ---- | ----------- | ------ | -| 200 | Action completed successfully | [SimpleResultResponse](#simpleresultresponse) | +| 200 | Action completed successfully | [DatasetMetadataActionResponse](#datasetmetadataactionresponse) | | 401 | Unauthorized - invalid API token | | | 404 | Dataset not found | | @@ -1503,11 +1503,11 @@ Update metadata name ##### Responses -| Code | Description | -| ---- | ----------- | -| 200 | Metadata updated successfully | -| 401 | Unauthorized - invalid API token | -| 404 | Dataset or metadata not found | +| Code | Description | Schema | +| ---- | ----------- | ------ | +| 200 | Metadata updated successfully | [DatasetMetadataResponse](#datasetmetadataresponse) | +| 401 | Unauthorized - invalid API token | | +| 404 | Dataset or metadata not found | | ### /datasets/{dataset_id}/pipeline/datasource-plugins @@ -2314,6 +2314,49 @@ Condition detail | page | integer | Page number | No | | tag_ids | [ string ] | Filter by tag IDs | No | +#### DatasetMetadataActionResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| result | string | | Yes | + +#### DatasetMetadataBuiltInFieldResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| name | string | | Yes | +| type | string | | Yes | + +#### DatasetMetadataBuiltInFieldsResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| fields | [ [DatasetMetadataBuiltInFieldResponse](#datasetmetadatabuiltinfieldresponse) ] | | Yes | + +#### DatasetMetadataListItemResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| count | integer | | No | +| id | string | | Yes | +| name | string | | Yes | +| type | string | | Yes | + +#### DatasetMetadataListResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| built_in_field_enabled | boolean | | Yes | +| doc_metadata | [ [DatasetMetadataListItemResponse](#datasetmetadatalistitemresponse) ] | | Yes | + +#### DatasetMetadataResponse + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| id | string | | Yes | +| name | string | | Yes | +| type | string | | Yes | + #### DatasetPermissionEnum | Name | Type | Description | Required | diff --git a/api/services/billing_service.py b/api/services/billing_service.py index c0e23cdc6f1..6021d46c724 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -116,7 +116,7 @@ class BillingInfo(TypedDict): subscription: _BillingSubscription members: _BillingQuota apps: _BillingQuota - vector_space: _VectorSpaceQuota + vector_space: NotRequired[_VectorSpaceQuota] knowledge_rate_limit: _KnowledgeRateLimit documents_upload_quota: _BillingQuota annotation_quota_limit: _BillingQuota @@ -128,6 +128,7 @@ class BillingInfo(TypedDict): _billing_info_adapter = TypeAdapter(BillingInfo) +_vector_space_quota_adapter = TypeAdapter(_VectorSpaceQuota) class KnowledgeRateLimitDict(TypedDict): @@ -185,12 +186,21 @@ class BillingService: _PLAN_CACHE_TTL = 600 @classmethod - def get_info(cls, tenant_id: str) -> BillingInfo: + def get_info(cls, tenant_id: str, exclude_vector_space: bool = False) -> BillingInfo: params = {"tenant_id": tenant_id} + if exclude_vector_space: + params["exclude_vector_space"] = "true" billing_info = cls._send_request("GET", "/subscription/info", params=params) return _billing_info_adapter.validate_python(billing_info) + @classmethod + def get_vector_space(cls, tenant_id: str) -> _VectorSpaceQuota: + params = {"tenant_id": tenant_id} + return _vector_space_quota_adapter.validate_python( + cls._send_request("GET", "/subscription/vector-space", params=params) + ) + @classmethod def get_tenant_feature_plan_usage_info(cls, tenant_id: str): """Deprecated: Use get_quota_info instead.""" diff --git a/api/services/feature_service.py b/api/services/feature_service.py index ce05df74c3e..ccef3415f9d 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,7 +6,7 @@ from configs import dify_config from constants.dsl_version import CURRENT_APP_DSL_VERSION from enums.cloud_plan import CloudPlan from enums.hosted_provider import HostedTrialProvider -from services.billing_service import BillingService +from services.billing_service import BillingInfo, BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -187,13 +187,17 @@ class SystemFeatureModel(FeatureResponseModel): class FeatureService: @classmethod - def get_features(cls, tenant_id: str) -> FeatureModel: + def get_features(cls, tenant_id: str, exclude_vector_space: bool = False) -> FeatureModel: features = FeatureModel() cls._fulfill_params_from_env(features) if dify_config.BILLING_ENABLED and tenant_id: - cls._fulfill_params_from_billing_api(features, tenant_id) + cls._fulfill_params_from_billing_api( + features, + tenant_id, + exclude_vector_space=exclude_vector_space, + ) if dify_config.ENTERPRISE_ENABLED: features.webapp_copyright_enabled = True @@ -207,6 +211,18 @@ class FeatureService: return features + @classmethod + def get_vector_space(cls, tenant_id: str) -> LimitationModel: + vector_space = LimitationModel(size=0, limit=5) + if dify_config.BILLING_ENABLED and tenant_id: + billing_vector_space = BillingService.get_vector_space(tenant_id) + # NOTE: billing API returns vector_space.size as float (e.g. 0.0), + # but feature API keeps LimitationModel.size as int for compatibility. + vector_space.size = int(billing_vector_space["size"]) + vector_space.limit = billing_vector_space["limit"] + + return vector_space + @classmethod def get_knowledge_rate_limit(cls, tenant_id: str): knowledge_rate_limit = KnowledgeRateLimitModel() @@ -291,8 +307,16 @@ class FeatureService: features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"] @classmethod - def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): - billing_info = BillingService.get_info(tenant_id) + def _fulfill_params_from_billing_api( + cls, + features: FeatureModel, + tenant_id: str, + exclude_vector_space: bool = False, + ): + if exclude_vector_space: + billing_info = BillingService.get_info(tenant_id, exclude_vector_space=True) + else: + billing_info = BillingService.get_info(tenant_id) features_usage_info = BillingService.get_quota_info(tenant_id) @@ -324,12 +348,8 @@ class FeatureService: features.apps.size = billing_info["apps"]["size"] features.apps.limit = billing_info["apps"]["limit"] - if "vector_space" in billing_info: - # NOTE (hj24): billing API returns vector_space.size as float (e.g. 0.0) - # but LimitationModel.size is int; truncate here for compatibility - features.vector_space.size = int(billing_info["vector_space"]["size"]) - # NOTE END - features.vector_space.limit = billing_info["vector_space"]["limit"] + if not exclude_vector_space: + cls._fulfill_vector_space_from_billing_info(features.vector_space, billing_info) if "documents_upload_quota" in billing_info: features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] @@ -361,6 +381,16 @@ class FeatureService: if "next_credit_reset_date" in billing_info: features.next_credit_reset_date = billing_info["next_credit_reset_date"] + @classmethod + def _fulfill_vector_space_from_billing_info(cls, vector_space: LimitationModel, billing_info: BillingInfo): + if "vector_space" not in billing_info: + return + + # NOTE: billing API returns vector_space.size as float (e.g. 0.0), + # but feature API keeps LimitationModel.size as int for compatibility. + vector_space.size = int(billing_info["vector_space"]["size"]) + vector_space.limit = billing_info["vector_space"]["limit"] + @classmethod def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel, is_authenticated: bool = False): enterprise_info = EnterpriseService.get_info() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a5ae83739cc..3c496d1fc8c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -45,7 +45,7 @@ class TestGetOAuthProviders: ) @patch("controllers.console.auth.oauth.dify_config") def test_should_configure_oauth_providers_correctly( - self, mock_config, app, github_config, google_config, expected_github, expected_google + self, mock_config, app: Flask, github_config, google_config, expected_github, expected_google ): mock_config.GITHUB_CLIENT_ID = github_config["id"] mock_config.GITHUB_CLIENT_SECRET = github_config["secret"] @@ -89,7 +89,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, invite_token, @@ -114,7 +114,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, ): @@ -136,7 +136,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, ): @@ -212,7 +212,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -237,7 +237,9 @@ class TestOAuthCallback: ], ) @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): + def test_should_handle_oauth_exceptions( + self, mock_get_providers, resource: OAuthCallback, app: Flask, exception, expected_error + ): # Import the real requests module to create a proper exception import httpx @@ -265,7 +267,7 @@ class TestOAuthCallback: mock_register_service, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -310,7 +312,7 @@ class TestOAuthCallback: mock_config, mock_tenant_service, mock_account_service, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, account_status, @@ -349,7 +351,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -385,7 +387,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -460,7 +462,12 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.Account") def test_should_get_account_by_openid_or_email( - self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account + self, + mock_account_model, + mock_get_account, + flask_req_ctx_with_containers, + user_info: OAuthUserInfo, + mock_account, ): # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account @@ -516,7 +523,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, mock_account, allow_register, existing_account, @@ -592,7 +599,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, ): mock_feature_service.get_system_features.return_value.is_allow_register = True mock_register_service.register.return_value = MagicMock() @@ -623,7 +630,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, ): mock_feature_service.get_system_features.return_value.is_allow_register = True mock_register_service.register.return_value = MagicMock() @@ -654,7 +661,7 @@ class TestAccountGeneration: mock_tenant_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, mock_account, ): mock_get_account.return_value = mock_account diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 917aa35fe68..b5f5917ee99 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -128,7 +128,7 @@ class TestConversationApi: body, status = result assert status == 204 - assert body["result"] == "success" + assert body == "" def test_delete_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index d944613886d..b977a3eb7ab 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from flask.testing import FlaskClient from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -73,7 +74,9 @@ def client(flask_app_with_containers: Flask): @patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True) @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") -def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): +def test_create_mcp_provider_populates_tools( + mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient +): # Arrange: reconnect returns tools immediately mock_reconnect.return_value = ReconnectResult( authed=True, diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py index c34da27ebe4..0ec399ba2b5 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py @@ -81,7 +81,7 @@ class TestConversationApi: result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) assert status == 204 - assert result["result"] == "success" + assert result == "" @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: diff --git a/api/tests/unit_tests/commands/test_lint_response_contracts.py b/api/tests/unit_tests/commands/test_lint_response_contracts.py new file mode 100644 index 00000000000..8f3860f2318 --- /dev/null +++ b/api/tests/unit_tests/commands/test_lint_response_contracts.py @@ -0,0 +1,191 @@ +import importlib.util +import sys +from pathlib import Path + + +def _load_lint_response_contracts_module(): + api_dir = Path(__file__).parents[3] + script_path = api_dir / "dev" / "lint_response_contracts.py" + spec = importlib.util.spec_from_file_location("lint_response_contracts", script_path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _checks_for_source(tmp_path: Path, source: str): + module = _load_lint_response_contracts_module() + controller_path = tmp_path / "controllers" / "sample.py" + controller_path.parent.mkdir() + controller_path.write_text(source, encoding="utf-8") + return module.checks_for_file(controller_path, tmp_path) + + +def test_no_body_status_with_body_is_mismatch_while_empty_body_is_valid(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/bad") +class BadDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return {"result": "success"}, 204 + + +@ns.route("/ok") +class EmptyDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return "", 204 +""", + ) + + assert [(check.class_name, check.classification) for check in checks] == [ + ("BadDeleteApi", "mismatch"), + ("EmptyDeleteApi", "valid"), + ] + assert "no-body response but returns raw_dict" in checks[0].reason + + +def test_variable_model_dump_is_refactorable_not_valid(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +from http import HTTPStatus + + +@ns.route("/annotations") +class AnnotationApi(Resource): + @ns.response(HTTPStatus.CREATED, "Created", ns.models[AnnotationResponse.__name__]) + def post(self): + if use_existing: + response = AnnotationResponse.model_validate(existing, from_attributes=True) + else: + response = AnnotationResponse(id="new") + return response.model_dump(mode="json"), HTTPStatus.CREATED +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "refactorable" + assert checks[0].actual[0].status == 201 + assert checks[0].actual[0].kind == "model_dump_variable" + assert "prefer dump_response" in checks[0].reason + + +def test_variable_model_dump_with_wrong_documented_schema_is_mismatch(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/annotations") +class AnnotationApi(Resource): + @ns.response(200, "OK", ns.models[DocumentedResponse.__name__]) + def get(self): + response = ActualResponse.model_validate(data) + return response.model_dump(mode="json"), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "mismatch" + assert "documents DocumentedResponse but returns ActualResponse" in checks[0].reason + + +def test_nested_returns_are_ignored_for_outer_control_flow(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/stream") +class StreamApi(Resource): + @ns.response(200, "OK", ns.models[StreamResponse.__name__]) + def get(self): + def generate_events(): + return dump_response(WrongResponse, {"event": "nested"}), 200 + + if finished: + return dump_response(StreamResponse, {"event": "done"}), 200 + return dump_response(StreamResponse, {"event": "running"}), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "valid" + assert {actual.model for actual in checks[0].actual} == {"StreamResponse"} + + +def test_main_is_report_only_by_default_for_mismatches(tmp_path: Path, monkeypatch): + module = _load_lint_response_contracts_module() + controller_path = tmp_path / "controllers" / "sample.py" + controller_path.parent.mkdir() + controller_path.write_text( + """ +@ns.route("/bad") +class BadDeleteApi(Resource): + @ns.response(204, "Deleted") + def delete(self): + return {"result": "success"}, 204 +""", + encoding="utf-8", + ) + + monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", str(controller_path)]) + assert module.main() == 0 + + monkeypatch.setattr(sys, "argv", ["lint_response_contracts.py", "--fail-on-mismatch", str(controller_path)]) + assert module.main() == 1 + + +def test_class_level_route_and_response_docs_apply_to_methods(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route(path="/items") +@ns.response(code=200, description="OK", model=ns.models[ItemListResponse.__name__]) +class ItemListApi(Resource): + def get(self): + return dump_response(ItemListResponse, {"data": []}), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "valid" + assert checks[0].route == "/items" + + +def test_unknown_reassignment_prevents_variable_model_dump_inference(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/items") +class ItemApi(Resource): + @ns.response(200, "OK", ns.models[ItemResponse.__name__]) + def get(self): + response = ItemResponse.model_validate(item) + if refresh: + response = load_response() + return response.model_dump(mode="json"), 200 +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "unknown" + assert "returns unknown" in checks[0].reason + + +def test_non_literal_status_is_unknown_not_defaulted_to_200(tmp_path: Path): + checks = _checks_for_source( + tmp_path, + """ +@ns.route("/items") +class ItemApi(Resource): + @ns.response(200, "OK", ns.models[ItemResponse.__name__]) + def get(self): + return dump_response(ItemResponse, item), status_code +""", + ) + + assert len(checks) == 1 + assert checks[0].classification == "unknown" + assert "non-literal or unsupported status" in checks[0].reason diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index e28d68ee5a7..4b0dff037f8 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -873,7 +873,7 @@ class TestDatasetApiDelete: result, status = method(api, dataset_id) assert status == 204 - assert result == {"result": "success"} + assert result == "" def test_delete_forbidden_no_permission(self, app: Flask): api = DatasetApi() @@ -1687,7 +1687,7 @@ class TestDatasetApiDeleteApi: response, status = method(api, "api-key-id") assert status == 204 - assert response["result"] == "success" + assert response == "" def test_delete_key_not_found(self, app: Flask): api = DatasetApiDeleteApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 66d257ee666..eb99c4eab3a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -940,7 +940,7 @@ class TestChildChunkUpdateApi: response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1") assert status == 204 - assert response["result"] == "success" + assert response == "" def test_delete_child_chunk_index_error(self, app: Flask): api = ChildChunkUpdateApi() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index 63221335366..b2863fc8cda 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -117,12 +117,13 @@ class TestDatasetMetadataCreateApi: patch.object( MetadataService, "create_metadata", - return_value={"id": "m1", "name": "author"}, + return_value={"id": "m1", "type": "string", "name": "author"}, ), ): result, status = method(api, dataset_id) assert status == 201 + assert result["type"] == "string" assert result["name"] == "author" def test_create_metadata_dataset_not_found(self, app: Flask, current_user, dataset_id): @@ -176,13 +177,17 @@ class TestDatasetMetadataGetApi: patch.object( MetadataService, "get_dataset_metadatas", - return_value=[{"id": "m1"}], + return_value={ + "doc_metadata": [{"id": "m1", "name": "author", "type": "string", "count": 0}], + "built_in_field_enabled": False, + }, ), ): result, status = method(api, dataset_id) assert status == 200 - assert isinstance(result, list) + assert result["doc_metadata"] == [{"id": "m1", "name": "author", "type": "string", "count": 0}] + assert result["built_in_field_enabled"] is False def test_get_metadata_dataset_not_found(self, app: Flask, dataset_id): api = DatasetMetadataCreateApi() @@ -231,12 +236,13 @@ class TestDatasetMetadataApi: patch.object( MetadataService, "update_metadata_name", - return_value={"id": "m1", "name": "updated-name"}, + return_value={"id": "m1", "type": "string", "name": "updated-name"}, ), ): result, status = method(api, dataset_id, metadata_id) assert status == 200 + assert result["type"] == "string" assert result["name"] == "updated-name" def test_delete_metadata_success(self, app: Flask, current_user, dataset, dataset_id, metadata_id): @@ -266,7 +272,7 @@ class TestDatasetMetadataApi: result, status = method(api, dataset_id, metadata_id) assert status == 204 - assert result["result"] == "success" + assert result == "" class TestDatasetMetadataBuiltInFieldApi: @@ -279,13 +285,19 @@ class TestDatasetMetadataBuiltInFieldApi: patch.object( MetadataService, "get_built_in_fields", - return_value=["title", "source"], + return_value=[ + {"name": "document_name", "type": "string"}, + {"name": "source", "type": "string"}, + ], ), ): result, status = method(api) assert status == 200 - assert result["fields"] == ["title", "source"] + assert result["fields"] == [ + {"name": "document_name", "type": "string"}, + {"name": "source", "type": "string"}, + ] class TestDatasetMetadataBuiltInFieldActionApi: @@ -315,8 +327,8 @@ class TestDatasetMetadataBuiltInFieldActionApi: ): result, status = method(api, dataset_id, "enable") - assert status == 200 - assert result["result"] == "success" + assert status == 204 + assert result == "" class TestDocumentMetadataEditApi: @@ -359,5 +371,5 @@ class TestDocumentMetadataEditApi: ): result, status = method(api, dataset_id) - assert status == 200 - assert result["result"] == "success" + assert status == 204 + assert result == "" diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py index ec82803be4f..47ac8d8f3f7 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -327,7 +327,7 @@ class TestInstalledAppApi: resp, status = method(installed_app) assert status == 204 - assert resp["result"] == "success" + assert resp == "" def test_delete_owned_by_current_tenant(self, tenant_id: str): api = module.InstalledAppApi() diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index 49e5695e60c..00c0d91d1db 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -141,7 +141,7 @@ class TestSavedMessageApi: delete_mock.assert_called_once() assert status == 204 - assert result == {"result": "success"} + assert result == "" def test_delete_not_completion_app(self): api = module.SavedMessageApi() diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 60a7ea5bb56..20fc62073b9 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -11,7 +11,7 @@ from flask.views import MethodView as FlaskMethodView _NEEDS_METHOD_VIEW_CLEANUP = False if not hasattr(builtins, "MethodView"): - builtins.MethodView = FlaskMethodView + builtins.__dict__["MethodView"] = FlaskMethodView _NEEDS_METHOD_VIEW_CLEANUP = True from constants import HIDDEN_VALUE @@ -22,7 +22,7 @@ from controllers.console.extension import ( ) if _NEEDS_METHOD_VIEW_CLEANUP: - del builtins.MethodView + del builtins.__dict__["MethodView"] from models.account import AccountStatus from models.api_based_extension import APIBasedExtension @@ -242,5 +242,5 @@ def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeyp response, status = APIBasedExtensionDetailAPI().delete(extension_id) delete_mock.assert_called_once_with(existing_extension) - assert response == {"result": "success"} assert status == 204 + assert response == "" diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py index 1711aede614..0339c507776 100644 --- a/api/tests/unit_tests/controllers/console/test_feature.py +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -20,8 +20,10 @@ class TestFeatureApi: return_value=("account_id", "tenant_123"), ) - mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = { - "features": {"feature_a": True} + get_features = mocker.patch("controllers.console.feature.FeatureService.get_features") + get_features.return_value.model_dump.return_value = { + "features": {"feature_a": True}, + "vector_space": {"size": 1, "limit": 2}, } api = FeatureApi() @@ -30,6 +32,28 @@ class TestFeatureApi: result = raw_get(api) assert result == {"features": {"feature_a": True}} + get_features.assert_called_once_with("tenant_123", exclude_vector_space=True) + + +class TestFeatureVectorSpaceApi: + def test_get_vector_space_success(self, mocker: MockerFixture): + from controllers.console.feature import FeatureVectorSpaceApi + + mocker.patch( + "controllers.console.feature.current_account_with_tenant", + return_value=("account_id", "tenant_123"), + ) + + get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space") + get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480} + + api = FeatureVectorSpaceApi() + + raw_get = unwrap(FeatureVectorSpaceApi.get) + result = raw_get(api) + + assert result == {"size": 5120, "limit": 20480} + get_vector_space.assert_called_once_with("tenant_123") class TestSystemFeatureApi: diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index 1be402c8aba..8e86709b669 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -98,6 +98,28 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest get_mock.assert_not_called() +def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.GetRemoteFileInfo() + handler = _unwrap(api.get) + target_url = "http://example.com/api/aiagent/httpview/txt" + query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27" + + head_resp = _FakeResponse( + status_code=200, + headers={"Content-Type": "text/plain", "Content-Length": "128"}, + method="HEAD", + ) + head_mock = MagicMock(return_value=head_resp) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + + with app.test_request_context(f"/remote-files/{target_url}?{query}", method="GET"): + payload = handler(api, url=target_url) + + assert payload == {"file_type": "text/plain", "file_length": 128} + head_mock.assert_called_once_with(f"{target_url}?{query}") + + def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() handler = _unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index e836a3cc554..a81a8e1b1af 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -179,8 +179,8 @@ class TestModelProviderCredentialApi: ): result, status = method(api, provider="openai") - assert result["result"] == "success" assert status == 204 + assert result == "" class TestModelProviderCredentialSwitchApi: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index b7e24f92017..5db87df0a2b 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -67,7 +67,6 @@ class TestDatasetMetadataCreatePost: def _call_post(api, **kwargs): return _unwrap(api.post)(api, **kwargs) - @patch("controllers.service_api.dataset.metadata.marshal") @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @patch("controllers.service_api.dataset.metadata.current_user") @@ -76,7 +75,6 @@ class TestDatasetMetadataCreatePost: mock_current_user, mock_dataset_svc, mock_meta_svc, - mock_marshal, app: Flask, mock_tenant, mock_dataset, @@ -84,9 +82,8 @@ class TestDatasetMetadataCreatePost: """Test successful metadata creation.""" mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_metadata = Mock() + mock_metadata = {"id": "meta-1", "type": "string", "name": "Author"} mock_meta_svc.create_metadata.return_value = mock_metadata - mock_marshal.return_value = {"id": "meta-1", "name": "Author"} with app.test_request_context( f"/datasets/{mock_dataset.id}/metadata", @@ -101,6 +98,7 @@ class TestDatasetMetadataCreatePost: ) assert status == 201 + assert response == {"id": "meta-1", "type": "string", "name": "Author"} mock_meta_svc.create_metadata.assert_called_once() @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -143,7 +141,10 @@ class TestDatasetMetadataCreateGet: ): """Test successful metadata list retrieval.""" mock_dataset_svc.get_dataset.return_value = mock_dataset - mock_meta_svc.get_dataset_metadatas.return_value = [{"id": "m1"}] + mock_meta_svc.get_dataset_metadatas.return_value = { + "doc_metadata": [{"id": "m1", "name": "Author", "type": "string", "count": 0}], + "built_in_field_enabled": False, + } with app.test_request_context( f"/datasets/{mock_dataset.id}/metadata", @@ -156,6 +157,10 @@ class TestDatasetMetadataCreateGet: ) assert status == 200 + assert response == { + "doc_metadata": [{"id": "m1", "name": "Author", "type": "string", "count": 0}], + "built_in_field_enabled": False, + } @patch("controllers.service_api.dataset.metadata.DatasetService") def test_get_metadata_dataset_not_found( @@ -192,7 +197,6 @@ class TestDatasetMetadataServiceApiPatch: def _call_patch(api, **kwargs): return _unwrap(api.patch)(api, **kwargs) - @patch("controllers.service_api.dataset.metadata.marshal") @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @patch("controllers.service_api.dataset.metadata.current_user") @@ -201,7 +205,6 @@ class TestDatasetMetadataServiceApiPatch: mock_current_user, mock_dataset_svc, mock_meta_svc, - mock_marshal, app: Flask, mock_tenant, mock_dataset, @@ -210,8 +213,7 @@ class TestDatasetMetadataServiceApiPatch: metadata_id = str(uuid.uuid4()) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - mock_meta_svc.update_metadata_name.return_value = Mock() - mock_marshal.return_value = {"id": metadata_id, "name": "New Name"} + mock_meta_svc.update_metadata_name.return_value = {"id": metadata_id, "type": "string", "name": "New Name"} with app.test_request_context( f"/datasets/{mock_dataset.id}/metadata/{metadata_id}", @@ -227,6 +229,7 @@ class TestDatasetMetadataServiceApiPatch: ) assert status == 200 + assert response == {"id": metadata_id, "type": "string", "name": "New Name"} mock_meta_svc.update_metadata_name.assert_called_once() @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -357,7 +360,7 @@ class TestDatasetMetadataBuiltInFieldGet: ) assert status == 200 - assert "fields" in response + assert response == {"fields": [{"name": "source", "type": "string"}]} # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/controllers/web/test_remote_files.py b/api/tests/unit_tests/controllers/web/test_remote_files.py index 8554f440b75..93f0ca99447 100644 --- a/api/tests/unit_tests/controllers/web/test_remote_files.py +++ b/api/tests/unit_tests/controllers/web/test_remote_files.py @@ -2,6 +2,7 @@ from __future__ import annotations +import urllib.parse from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -36,6 +37,39 @@ class TestRemoteFileInfoApi: assert result["file_type"] == "application/pdf" assert result["file_length"] == 1024 + mock_proxy.head.assert_called_once_with("https://example.com/file.pdf") + + @patch("controllers.web.remote_files.ssrf_proxy") + def test_preserves_unencoded_target_query(self, mock_proxy: MagicMock, app: Flask) -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "text/plain", "Content-Length": "128"} + mock_proxy.head.return_value = mock_resp + + target_url = "http://example.com/api/aiagent/httpview/txt" + query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27" + + with app.test_request_context(f"/remote-files/{target_url}?{query}"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), target_url) + + assert result["file_type"] == "text/plain" + mock_proxy.head.assert_called_once_with(f"{target_url}?{query}") + + @patch("controllers.web.remote_files.ssrf_proxy") + def test_preserves_encoded_target_query(self, mock_proxy: MagicMock, app: Flask) -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "text/plain", "Content-Length": "128"} + mock_proxy.head.return_value = mock_resp + + target_url = "http://example.com/api/aiagent/httpview/txt?fileNameKey=cankao1" + encoded_url = urllib.parse.quote(target_url, safe="") + + with app.test_request_context(f"/remote-files/{encoded_url}"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), encoded_url) + + assert result["file_type"] == "text/plain" + mock_proxy.head.assert_called_once_with(target_url) @patch("controllers.web.remote_files.ssrf_proxy") def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None: diff --git a/api/tests/unit_tests/controllers/web/test_saved_message.py b/api/tests/unit_tests/controllers/web/test_saved_message.py index 3d558049127..5de740192fe 100644 --- a/api/tests/unit_tests/controllers/web/test_saved_message.py +++ b/api/tests/unit_tests/controllers/web/test_saved_message.py @@ -94,4 +94,4 @@ class TestSavedMessageApi: result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id) assert status == 204 - assert result["result"] == "success" + assert result == "" diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index d21b9e471b3..ecbd9691e98 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -183,7 +183,7 @@ class TestNotionExtractorPageRetrieval: } @patch("httpx.request") - def test_get_notion_block_data_simple_page(self, mock_request, extractor): + def test_get_notion_block_data_simple_page(self, mock_request, extractor: NotionExtractor): """Test retrieving simple page with basic blocks.""" # Arrange mock_data = { @@ -207,7 +207,7 @@ class TestNotionExtractorPageRetrieval: mock_request.assert_called_once() @patch("httpx.request") - def test_get_notion_block_data_with_headings(self, mock_request, extractor): + def test_get_notion_block_data_with_headings(self, mock_request, extractor: NotionExtractor): """Test retrieving page with heading blocks.""" # Arrange mock_data = { @@ -234,7 +234,7 @@ class TestNotionExtractorPageRetrieval: assert "### Sub-subtitle" in result[3] @patch("httpx.request") - def test_get_notion_block_data_with_pagination(self, mock_request, extractor): + def test_get_notion_block_data_with_pagination(self, mock_request, extractor: NotionExtractor): """Test retrieving page with paginated results.""" # Arrange first_page = { @@ -264,7 +264,7 @@ class TestNotionExtractorPageRetrieval: assert mock_request.call_count == 2 @patch("httpx.request") - def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with nested block structure.""" # Arrange # First call returns parent blocks @@ -300,7 +300,7 @@ class TestNotionExtractorPageRetrieval: assert mock_request.call_count == 2 @patch("httpx.request") - def test_get_notion_block_data_error_handling(self, mock_request, extractor): + def test_get_notion_block_data_error_handling(self, mock_request, extractor: NotionExtractor): """Test error handling for failed API requests.""" # Arrange mock_request.return_value = self._create_mock_response({}, status_code=404) @@ -311,7 +311,7 @@ class TestNotionExtractorPageRetrieval: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.request") - def test_get_notion_block_data_invalid_response(self, mock_request, extractor): + def test_get_notion_block_data_invalid_response(self, mock_request, extractor: NotionExtractor): """Test handling of invalid API response structure.""" # Arrange mock_request.return_value = self._create_mock_response({"invalid": "structure"}) @@ -322,7 +322,7 @@ class TestNotionExtractorPageRetrieval: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.request") - def test_get_notion_block_data_http_error(self, mock_request, extractor): + def test_get_notion_block_data_http_error(self, mock_request, extractor: NotionExtractor): """Test handling of HTTP errors during request.""" # Arrange mock_request.side_effect = httpx.HTTPError("Network error") @@ -368,7 +368,7 @@ class TestNotionExtractorDatabaseRetrieval: } @patch("httpx.post") - def test_get_notion_database_data_simple(self, mock_post, extractor): + def test_get_notion_database_data_simple(self, mock_post, extractor: NotionExtractor): """Test retrieving simple database with basic properties.""" # Arrange mock_response = Mock() @@ -407,7 +407,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Status:Done" in content @patch("httpx.post") - def test_get_notion_database_data_with_pagination(self, mock_post, extractor): + def test_get_notion_database_data_with_pagination(self, mock_post, extractor: NotionExtractor): """Test retrieving database with paginated results.""" # Arrange first_response = Mock() @@ -441,7 +441,7 @@ class TestNotionExtractorDatabaseRetrieval: assert mock_post.call_count == 2 @patch("httpx.post") - def test_get_notion_database_data_multi_select(self, mock_post, extractor): + def test_get_notion_database_data_multi_select(self, mock_post, extractor: NotionExtractor): """Test database with multi_select property type.""" # Arrange mock_response = Mock() @@ -474,7 +474,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Tags:" in content @patch("httpx.post") - def test_get_notion_database_data_empty_properties(self, mock_post, extractor): + def test_get_notion_database_data_empty_properties(self, mock_post, extractor: NotionExtractor): """Test database with empty property values.""" # Arrange mock_response = Mock() @@ -504,7 +504,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Row Page URL:" in content @patch("httpx.post") - def test_get_notion_database_data_empty_results(self, mock_post, extractor): + def test_get_notion_database_data_empty_results(self, mock_post, extractor: NotionExtractor): """Test handling of empty database.""" # Arrange mock_response = Mock() @@ -523,7 +523,7 @@ class TestNotionExtractorDatabaseRetrieval: assert len(result) == 0 @patch("httpx.post") - def test_get_notion_database_data_missing_results(self, mock_post, extractor): + def test_get_notion_database_data_missing_results(self, mock_post, extractor: NotionExtractor): """Test handling of malformed API response.""" # Arrange mock_response = Mock() @@ -559,7 +559,7 @@ class TestNotionExtractorTableParsing: ) @patch("httpx.request") - def test_read_table_rows_simple(self, mock_request, extractor): + def test_read_table_rows_simple(self, mock_request, extractor: NotionExtractor): """Test reading simple table with headers and rows.""" # Arrange mock_data = { @@ -611,7 +611,7 @@ class TestNotionExtractorTableParsing: assert "| Bob | 25 |" in result @patch("httpx.request") - def test_read_table_rows_with_empty_cells(self, mock_request, extractor): + def test_read_table_rows_with_empty_cells(self, mock_request, extractor: NotionExtractor): """Test reading table with empty cells.""" # Arrange mock_data = { @@ -643,7 +643,7 @@ class TestNotionExtractorTableParsing: assert "Value1" in result @patch("httpx.request") - def test_read_table_rows_with_pagination(self, mock_request, extractor): + def test_read_table_rows_with_pagination(self, mock_request, extractor: NotionExtractor): """Test reading table with paginated results.""" # Arrange first_page = { @@ -960,7 +960,7 @@ class TestNotionExtractorReadBlock: ) @patch("httpx.request") - def test_read_block_with_indentation(self, mock_request, extractor): + def test_read_block_with_indentation(self, mock_request, extractor: NotionExtractor): """Test reading nested blocks with proper indentation.""" # Arrange mock_data = { @@ -990,7 +990,7 @@ class TestNotionExtractorReadBlock: assert "\t\tNested content" in result @patch("httpx.request") - def test_read_block_skip_child_page(self, mock_request, extractor): + def test_read_block_skip_child_page(self, mock_request, extractor: NotionExtractor): """Test that child_page blocks don't recurse.""" # Arrange mock_data = { @@ -1139,7 +1139,7 @@ class TestNotionExtractorAdvancedBlockTypes: } @patch("httpx.request") - def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with bulleted and numbered list items. Both list types should be extracted with their content. @@ -1165,7 +1165,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Numbered item" in result[1] @patch("httpx.request") - def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with code, quote, and callout blocks. Special block types should preserve their content correctly. @@ -1193,7 +1193,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Important note" in result[2] @patch("httpx.request") - def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor): + def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor: NotionExtractor): """Test retrieving page with toggle block containing children. Toggle blocks can have nested content that should be extracted. @@ -1229,7 +1229,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Hidden content" in result[0] @patch("httpx.request") - def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor): + def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor: NotionExtractor): """Test retrieving page with mixed block types. Real Notion pages contain various block types mixed together. @@ -1308,7 +1308,7 @@ class TestNotionExtractorDatabaseAdvanced: } @patch("httpx.post") - def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor): + def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor: NotionExtractor): """Test database with multiple property types. Tests date, number, checkbox, URL, email, phone, and status properties. @@ -1354,7 +1354,7 @@ class TestNotionExtractorDatabaseAdvanced: assert "Status:Active" in content @patch("httpx.post") - def test_get_notion_database_data_large_pagination(self, mock_post, extractor): + def test_get_notion_database_data_large_pagination(self, mock_post, extractor: NotionExtractor): """Test database with multiple pages of results. Large databases require multiple API calls with cursor-based pagination. @@ -1415,7 +1415,7 @@ class TestNotionExtractorDatabaseAdvanced: assert mock_post.call_count == 3 @patch("httpx.post") - def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor): + def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor: NotionExtractor): """Test database with rich_text property type. Rich text properties can contain formatted text and should be extracted. @@ -1486,7 +1486,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_network_errors(self, mock_request, extractor, error_type, error_value): + def test_get_notion_block_data_network_errors( + self, mock_request, extractor: NotionExtractor, error_type, error_value + ): """Test handling of various network errors. Network issues (timeouts, connection failures) should raise appropriate errors. @@ -1509,7 +1511,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_http_status_errors(self, mock_request, extractor, status_code, description): + def test_get_notion_block_data_http_status_errors( + self, mock_request, extractor: NotionExtractor, status_code, description + ): """Test handling of various HTTP status errors. Different HTTP error codes (401, 403, 404, 429) should be handled appropriately. @@ -1534,7 +1538,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_malformed_responses(self, mock_request, extractor, response_data, description): + def test_get_notion_block_data_malformed_responses( + self, mock_request, extractor: NotionExtractor, response_data, description + ): """Test handling of malformed API responses. Various malformed responses should be handled gracefully. @@ -1551,7 +1557,7 @@ class TestNotionExtractorErrorScenarios: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.post") - def test_get_notion_database_data_with_query_filter(self, mock_post, extractor): + def test_get_notion_database_data_with_query_filter(self, mock_post, extractor: NotionExtractor): """Test database query with custom filter. Databases can be queried with filters to retrieve specific rows. @@ -1618,7 +1624,7 @@ class TestNotionExtractorTableAdvanced: ) @patch("httpx.request") - def test_read_table_rows_with_many_columns(self, mock_request, extractor): + def test_read_table_rows_with_many_columns(self, mock_request, extractor: NotionExtractor): """Test reading table with many columns. Tables can have numerous columns; all should be extracted correctly. diff --git a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py index d9fed9ae2ad..0659765da40 100644 --- a/api/tests/unit_tests/core/helper/test_ssrf_proxy.py +++ b/api/tests/unit_tests/core/helper/test_ssrf_proxy.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, call, patch import httpx import pytest @@ -6,6 +6,7 @@ import pytest from core.helper.ssrf_proxy import ( SSRF_DEFAULT_MAX_RETRIES, SSRFProxy, + _build_ssrf_client, _get_user_provided_host_header, _to_graphon_http_response, graphon_ssrf_proxy, @@ -41,6 +42,34 @@ def test_retry_exceed_max_retries(mock_get_client): assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com" +def test_build_ssrf_client_passes_ssl_verify_to_proxy_mount_transports(): + mock_client = MagicMock() + http_transport = MagicMock() + https_transport = MagicMock() + + with ( + patch("core.helper.ssrf_proxy.dify_config.SSRF_PROXY_ALL_URL", None), + patch("core.helper.ssrf_proxy.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy.example.com:8080"), + patch("core.helper.ssrf_proxy.dify_config.SSRF_PROXY_HTTPS_URL", "http://proxy.example.com:8443"), + patch("core.helper.ssrf_proxy.httpx.HTTPTransport", side_effect=[http_transport, https_transport]) as transport, + patch("core.helper.ssrf_proxy.httpx.Client", return_value=mock_client) as client, + ): + ssrf_client = _build_ssrf_client(verify=False) + + assert ssrf_client is mock_client + transport.assert_has_calls( + [ + call(proxy="http://proxy.example.com:8080", verify=False), + call(proxy="http://proxy.example.com:8443", verify=False), + ], + ) + client.assert_called_once_with( + mounts={"http://": http_transport, "https://": https_transport}, + verify=False, + limits=ANY, + ) + + class TestGetUserProvidedHostHeader: """Tests for _get_user_provided_host_header function.""" diff --git a/api/tests/unit_tests/core/moderation/test_output_moderation.py b/api/tests/unit_tests/core/moderation/test_output_moderation.py index 36a80cc76c2..ce384c4c13d 100644 --- a/api/tests/unit_tests/core/moderation/test_output_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_output_moderation.py @@ -19,22 +19,22 @@ class TestOutputModeration: return ModerationRule(type="keywords", config={"keywords": "badword"}) @pytest.fixture - def output_moderation(self, mock_queue_manager, moderation_rule): + def output_moderation(self, mock_queue_manager, moderation_rule: ModerationRule): return OutputModeration( tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager ) - def test_should_direct_output(self, output_moderation): + def test_should_direct_output(self, output_moderation: OutputModeration): assert output_moderation.should_direct_output() is False output_moderation.final_output = "blocked" assert output_moderation.should_direct_output() is True - def test_get_final_output(self, output_moderation): + def test_get_final_output(self, output_moderation: OutputModeration): assert output_moderation.get_final_output() == "" output_moderation.final_output = "blocked" assert output_moderation.get_final_output() == "blocked" - def test_append_new_token(self, output_moderation): + def test_append_new_token(self, output_moderation: OutputModeration): with patch.object(OutputModeration, "start_thread") as mock_start: output_moderation.append_new_token("hello") assert output_moderation.buffer == "hello" @@ -45,7 +45,7 @@ class TestOutputModeration: assert output_moderation.buffer == "hello world" assert mock_start.call_count == 1 - def test_moderation_completion_no_flag(self, output_moderation): + def test_moderation_completion_no_flag(self, output_moderation: OutputModeration): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) @@ -55,7 +55,7 @@ class TestOutputModeration: assert flagged is False assert output_moderation.is_final_chunk is True - def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager): + def test_moderation_completion_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult( flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" @@ -71,7 +71,7 @@ class TestOutputModeration: assert args[0].text == "preset" assert args[1] == PublishFrom.TASK_PIPELINE - def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager): + def test_moderation_completion_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult( flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content" @@ -85,7 +85,7 @@ class TestOutputModeration: args, _ = mock_queue_manager.publish.call_args assert args[0].text == "masked content" - def test_start_thread(self, output_moderation): + def test_start_thread(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch("core.moderation.output_moderation.current_app") as mock_current_app: mock_current_app._get_current_object = MagicMock(return_value=mock_app) @@ -99,7 +99,7 @@ class TestOutputModeration: mock_thread_class.assert_called_once() mock_thread_instance.start.assert_called_once() - def test_stop_thread(self, output_moderation): + def test_stop_thread(self, output_moderation: OutputModeration): mock_thread = MagicMock() mock_thread.is_alive.return_value = True output_moderation.thread = mock_thread @@ -113,7 +113,7 @@ class TestOutputModeration: assert output_moderation.thread_running is True @patch("core.moderation.output_moderation.ModerationFactory") - def test_moderation_success(self, mock_factory_class, output_moderation): + def test_moderation_success(self, mock_factory_class, output_moderation: OutputModeration): mock_factory = mock_factory_class.return_value mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) mock_factory.moderation_for_outputs.return_value = mock_result @@ -126,13 +126,13 @@ class TestOutputModeration: ) @patch("core.moderation.output_moderation.ModerationFactory") - def test_moderation_exception(self, mock_factory_class, output_moderation): + def test_moderation_exception(self, mock_factory_class, output_moderation: OutputModeration): mock_factory_class.side_effect = Exception("error") result = output_moderation.moderation("tenant", "app", "buffer") assert result is None - def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager): + def test_worker_loop_and_exit(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) # Test exit on thread_running=False @@ -140,7 +140,7 @@ class TestOutputModeration: output_moderation.worker(mock_app, 10) # Should exit immediately - def test_worker_no_flag(self, output_moderation): + def test_worker_no_flag(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -160,7 +160,7 @@ class TestOutputModeration: assert mock_moderation.called - def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager): + def test_worker_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -177,7 +177,7 @@ class TestOutputModeration: mock_queue_manager.publish.assert_called_once() # It breaks on DIRECT_OUTPUT - def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager): + def test_worker_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -199,7 +199,7 @@ class TestOutputModeration: args, _ = mock_queue_manager.publish.call_args assert args[0].text == "masked" - def test_worker_chunk_too_small(self, output_moderation): + def test_worker_chunk_too_small(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch("time.sleep") as mock_sleep: # chunk_length < buffer_size and not is_final_chunk @@ -215,7 +215,7 @@ class TestOutputModeration: mock_sleep.assert_called_once_with(1) - def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager): + def test_worker_empty_not_flagged(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: # Return None (exception or no rule) diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 33a32936825..704f5d362c6 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -262,7 +262,7 @@ def workflow_repo_fixture(monkeypatch: pytest.MonkeyPatch): @pytest.fixture -def trace_task_message(monkeypatch, mock_db): +def trace_task_message(monkeypatch: pytest.MonkeyPatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) @@ -353,7 +353,7 @@ def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch: pytest.Mo assert OpsTraceManager.get_ops_trace_instance("app-id") is None -def test_get_ops_trace_instance_success(monkeypatch, mock_db): +def test_get_ops_trace_instance_success(monkeypatch: pytest.MonkeyPatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) mock_db.get.return_value = app monkeypatch.setattr( @@ -497,7 +497,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): assert result.documents == [{"doc": "value"}] -def test_trace_task_tool_trace(monkeypatch, mock_db): +def test_trace_task_tool_trace(monkeypatch: pytest.MonkeyPatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) configure_db_scalar(mock_db, message_file=FakeMessageFile()) diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index e7f7cabecd2..ea60b94b612 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -87,6 +87,7 @@ from uuid import uuid4 import pytest from flask import Flask +from flask.testing import FlaskClient from flask_restx import Api from controllers.console.datasets.datasets import DatasetApi, DatasetListApi @@ -339,7 +340,7 @@ class TestDatasetListApi: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_get_datasets_success(self, client, mock_current_user): + def test_get_datasets_success(self, client: FlaskClient, mock_current_user): """ Test successful retrieval of dataset list. @@ -380,7 +381,7 @@ class TestDatasetListApi: # Verify service was called mock_get_datasets.assert_called_once() - def test_get_datasets_with_search(self, client, mock_current_user): + def test_get_datasets_with_search(self, client: FlaskClient, mock_current_user): """ Test dataset listing with search keyword. @@ -410,7 +411,7 @@ class TestDatasetListApi: call_args = mock_get_datasets.call_args assert call_args[1]["search"] == search_keyword - def test_get_datasets_with_pagination(self, client, mock_current_user): + def test_get_datasets_with_pagination(self, client: FlaskClient, mock_current_user): """ Test dataset listing with pagination parameters. @@ -495,7 +496,7 @@ class TestDatasetApiGet: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_get_dataset_success(self, client, mock_current_user): + def test_get_dataset_success(self, client: FlaskClient, mock_current_user): """ Test successful retrieval of a single dataset. @@ -533,7 +534,7 @@ class TestDatasetApiGet: mock_get_dataset.assert_called_once_with(dataset_id) mock_check_perm.assert_called_once() - def test_get_dataset_not_found(self, client, mock_current_user): + def test_get_dataset_not_found(self, client: FlaskClient, mock_current_user): """ Test error handling when dataset is not found. @@ -611,7 +612,7 @@ class TestDatasetApiCreate: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_create_dataset_success(self, client, mock_current_user): + def test_create_dataset_success(self, client: FlaskClient, mock_current_user): """ Test successful creation of a dataset. @@ -706,7 +707,7 @@ class TestHitTestingApi: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_hit_testing_success(self, client, mock_current_user): + def test_hit_testing_success(self, client: FlaskClient, mock_current_user): """ Test successful hit testing operation. diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index 36592196c69..e7a195a4727 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -313,6 +313,54 @@ class TestBillingServiceSubscriptionInfo: assert result == expected_response mock_send_request.assert_called_once_with("GET", "/subscription/info", params={"tenant_id": tenant_id}) + def test_get_info_exclude_vector_space(self, mock_send_request): + """When requested, get_info asks billing to skip vector_space.""" + # Arrange + tenant_id = "tenant-123" + expected_response = { + "enabled": True, + "subscription": {"plan": "professional", "interval": "month", "education": False}, + "members": {"size": 1, "limit": 50}, + "apps": {"size": 1, "limit": 200}, + "knowledge_rate_limit": {"limit": 1000}, + "documents_upload_quota": {"size": 0, "limit": 1000}, + "annotation_quota_limit": {"size": 0, "limit": 5000}, + "docs_processing": "top-priority", + "can_replace_logo": True, + "model_load_balancing_enabled": True, + "knowledge_pipeline_publish_enabled": True, + } + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_info(tenant_id, exclude_vector_space=True) + + # Assert + assert "vector_space" not in result + mock_send_request.assert_called_once_with( + "GET", + "/subscription/info", + params={"tenant_id": tenant_id, "exclude_vector_space": "true"}, + ) + + def test_get_vector_space_success(self, mock_send_request): + """Test successful retrieval of vector-space usage and limit.""" + # Arrange + tenant_id = "tenant-123" + expected_response = {"size": 5120.75, "limit": 20480} + mock_send_request.return_value = expected_response + + # Act + result = BillingService.get_vector_space(tenant_id) + + # Assert + assert result == expected_response + mock_send_request.assert_called_once_with( + "GET", + "/subscription/vector-space", + params={"tenant_id": tenant_id}, + ) + def test_get_knowledge_rate_limit_with_defaults(self, mock_send_request): """Test knowledge rate limit retrieval with default values.""" # Arrange @@ -1744,8 +1792,9 @@ class TestBillingServiceSubscriptionInfoDataType: assert isinstance(result["apps"]["size"], int) assert isinstance(result["apps"]["limit"], int) - assert isinstance(result["vector_space"]["size"], float) - assert isinstance(result["vector_space"]["limit"], int) + if "vector_space" in result: + assert isinstance(result["vector_space"]["size"], float) + assert isinstance(result["vector_space"]["limit"], int) assert isinstance(result["knowledge_rate_limit"]["limit"], int) @@ -1783,11 +1832,13 @@ class TestBillingServiceSubscriptionInfoDataType: def test_get_info_without_optional_fields(self, mock_send_request, string_billing_response): """NotRequired fields can be absent without raising.""" del string_billing_response["next_credit_reset_date"] + del string_billing_response["vector_space"] mock_send_request.return_value = string_billing_response result = BillingService.get_info("tenant-type-test") assert "next_credit_reset_date" not in result + assert "vector_space" not in result self._assert_billing_info_types(result) def test_get_info_with_extra_fields(self, mock_send_request, string_billing_response): diff --git a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py index ab141a7b2d8..8614d351f19 100644 --- a/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py +++ b/api/tests/unit_tests/services/test_feature_service_human_input_email_delivery.py @@ -102,3 +102,17 @@ def test_resolve_human_input_email_delivery_enabled_matrix( ) assert result is case.expected + + +def test_get_vector_space_converts_billing_float_size(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", True) + monkeypatch.setattr( + feature_service_module.BillingService, + "get_vector_space", + lambda tenant_id: {"size": 5120.75, "limit": 20480}, + ) + + result = FeatureService.get_vector_space("tenant-1") + + assert result.size == 5120 + assert result.limit == 20480 diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index 69bd194a684..2e6ca7dbb9c 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -52,7 +52,7 @@ class TestFileService: @patch("services.file_service.extract_tenant_id") @patch("services.file_service.file_helpers.get_signed_file_url") def test_upload_file_success( - self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session + self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service: FileService, mock_db_session ): # Setup mock_tenant_id.return_value = "tenant_id" @@ -88,7 +88,7 @@ class TestFileService: with pytest.raises(ValueError, match="Filename contains invalid characters"): file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock()) - def test_upload_file_long_filename(self, file_service, mock_db_session): + def test_upload_file_long_filename(self, file_service: FileService, mock_db_session): # Setup long_name = "a" * 210 + ".txt" user = MagicMock(spec=Account) @@ -124,7 +124,7 @@ class TestFileService: with pytest.raises(FileTooLargeError): file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock()) - def test_upload_file_end_user(self, file_service, mock_db_session): + def test_upload_file_end_user(self, file_service: FileService, mock_db_session): user = MagicMock(spec=EndUser) user.id = "end_user_id" @@ -160,7 +160,7 @@ class TestFileService: assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False - def test_get_file_base64_success(self, file_service, mock_db_session): + def test_get_file_base64_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -177,12 +177,12 @@ class TestFileService: assert result == base64.b64encode(b"test content").decode() mock_storage.load_once.assert_called_once_with("test_key") - def test_get_file_base64_not_found(self, file_service, mock_db_session): + def test_get_file_base64_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_base64("non_existent") - def test_upload_text_success(self, file_service, mock_db_session): + def test_upload_text_success(self, file_service: FileService, mock_db_session): # Setup text = "sample text" text_name = "test.txt" @@ -204,13 +204,13 @@ class TestFileService: mock_db_session.add.assert_called_once() mock_db_session.commit.assert_called_once() - def test_upload_text_long_name(self, file_service, mock_db_session): + def test_upload_text_long_name(self, file_service: FileService, mock_db_session): long_name = "a" * 210 with patch("services.file_service.storage"): result = file_service.upload_text("text", long_name, "user", "tenant") assert len(result.name) == 200 - def test_get_file_preview_success(self, file_service, mock_db_session): + def test_get_file_preview_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -226,12 +226,12 @@ class TestFileService: # Assert assert result == "Extracted text content" - def test_get_file_preview_not_found(self, file_service, mock_db_session): + def test_get_file_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_preview("non_existent", "tenant_id") - def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_file_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "exe" @@ -239,7 +239,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_file_preview("file_id", "tenant_id") - def test_get_image_preview_success(self, file_service, mock_db_session): + def test_get_image_preview_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -268,14 +268,14 @@ class TestFileService: with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_image_preview_not_found(self, file_service, mock_db_session): + def test_get_image_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_image_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_image_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" @@ -285,7 +285,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session): + def test_get_file_generator_by_file_id_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -308,14 +308,14 @@ class TestFileService: with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") - def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): + def test_get_file_generator_by_file_id_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") - def test_get_public_image_preview_success(self, file_service, mock_db_session): + def test_get_public_image_preview_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "png" @@ -329,12 +329,12 @@ class TestFileService: assert gen == b"image content" assert mime == "image/png" - def test_get_public_image_preview_not_found(self, file_service, mock_db_session): + def test_get_public_image_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_public_image_preview("file_id") - def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_public_image_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" @@ -342,7 +342,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_public_image_preview("file_id") - def test_get_file_content_success(self, file_service, mock_db_session): + def test_get_file_content_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -353,12 +353,12 @@ class TestFileService: result = file_service.get_file_content("file_id") assert result == "hello world" - def test_get_file_content_not_found(self, file_service, mock_db_session): + def test_get_file_content_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_content("file_id") - def test_delete_file_success(self, file_service, mock_db_session): + def test_delete_file_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -370,7 +370,7 @@ class TestFileService: mock_storage.delete.assert_called_once_with("key") mock_db_session.delete.assert_called_once_with(upload_file) - def test_delete_file_not_found(self, file_service, mock_db_session): + def test_delete_file_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None file_service.delete_file("file_id") # Should return without doing anything diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index b4332334aba..7ce897eb029 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -151,9 +151,9 @@ class TestErrorHandling: def test_clean_dataset_task_rollback_failure_still_closes_session( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -198,9 +198,9 @@ class TestPipelineAndWorkflowDeletion: def test_clean_dataset_task_with_pipeline_id( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, pipeline_id, mock_db_session, mock_storage, @@ -231,9 +231,9 @@ class TestPipelineAndWorkflowDeletion: def test_clean_dataset_task_without_pipeline_id( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -271,9 +271,9 @@ class TestSegmentAttachmentCleanup: def test_clean_dataset_task_with_attachments( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -321,9 +321,9 @@ class TestSegmentAttachmentCleanup: def test_clean_dataset_task_attachment_storage_failure( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -375,9 +375,9 @@ class TestEdgeCases: def test_clean_dataset_task_session_always_closed( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -413,9 +413,9 @@ class TestIndexProcessorParameters: def test_clean_dataset_task_passes_correct_parameters_to_index_processor( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 41d3068a103..e5782899e3b 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -124,10 +124,10 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, - notion_workspace_id, - notion_page_id, + dataset_id: str, + document_id: str, + notion_workspace_id: str, + notion_page_id: str, ): """Test that NotionExtractor is initialized with expected arguments.""" # Arrange @@ -151,9 +151,9 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, - credential_id, + dataset_id: str, + document_id: str, + credential_id: str, ): """Test that datasource credentials are requested with expected identifiers.""" # Arrange @@ -176,8 +176,8 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, + dataset_id: str, + document_id: str, ): """Test that missing credential_id is forwarded as None.""" # Arrange @@ -212,8 +212,8 @@ class TestDataSourceInfoSerialization: self, mock_document, mock_dataset, - dataset_id, - document_id, + dataset_id: str, + document_id: str, ): """data_source_info must be serialized with json.dumps before DB write.""" with ( diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 2bcd7d8b78c..ba78fc1df1d 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -2340,9 +2340,6 @@ } }, "web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 2 } diff --git a/packages/contracts/generated/api/console/datasets/orpc.gen.ts b/packages/contracts/generated/api/console/datasets/orpc.gen.ts index 4623d369c8d..baec823590b 100644 --- a/packages/contracts/generated/api/console/datasets/orpc.gen.ts +++ b/packages/contracts/generated/api/console/datasets/orpc.gen.ts @@ -495,16 +495,8 @@ export const init = { post: post6, } -/** - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated - */ export const get7 = oc .route({ - deprecated: true, - description: - 'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', inputStructure: 'detailed', method: 'GET', operationId: 'getDatasetsMetadataBuiltIn', @@ -799,6 +791,7 @@ export const post11 = oc method: 'POST', operationId: 'postDatasetsByDatasetIdDocumentsMetadata', path: '/datasets/{dataset_id}/documents/metadata', + successStatus: 204, tags: ['console'], }) .input( @@ -1643,6 +1636,7 @@ export const post19 = oc method: 'POST', operationId: 'postDatasetsByDatasetIdMetadataBuiltInByAction', path: '/datasets/{dataset_id}/metadata/built-in/{action}', + successStatus: 204, tags: ['console'], }) .input(z.object({ params: zPostDatasetsByDatasetIdMetadataBuiltInByActionPath })) @@ -1668,16 +1662,8 @@ export const delete8 = oc .input(z.object({ params: zDeleteDatasetsByDatasetIdMetadataByMetadataIdPath })) .output(zDeleteDatasetsByDatasetIdMetadataByMetadataIdResponse) -/** - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated - */ export const patch10 = oc .route({ - deprecated: true, - description: - 'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', inputStructure: 'detailed', method: 'PATCH', operationId: 'patchDatasetsByDatasetIdMetadataByMetadataId', @@ -1697,16 +1683,8 @@ export const byMetadataId = { patch: patch10, } -/** - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated - */ export const get29 = oc .route({ - deprecated: true, - description: - 'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', inputStructure: 'detailed', method: 'GET', operationId: 'getDatasetsByDatasetIdMetadata', @@ -1716,20 +1694,13 @@ export const get29 = oc .input(z.object({ params: zGetDatasetsByDatasetIdMetadataPath })) .output(zGetDatasetsByDatasetIdMetadataResponse) -/** - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated - */ export const post20 = oc .route({ - deprecated: true, - description: - 'Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', inputStructure: 'detailed', method: 'POST', operationId: 'postDatasetsByDatasetIdMetadata', path: '/datasets/{dataset_id}/metadata', + successStatus: 201, tags: ['console'], }) .input( diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 2d9f4945153..92df8a52a50 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -140,6 +140,10 @@ export type DatasetAndDocumentResponse = { documents: Array } +export type DatasetMetadataBuiltInFieldsResponse = { + fields: Array +} + export type TextContentResponse = { content: string } @@ -291,11 +295,22 @@ export type HitTestingResponse = { records?: Array } +export type DatasetMetadataListResponse = { + built_in_field_enabled: boolean + doc_metadata: Array +} + export type MetadataArgs = { name: string type: 'number' | 'string' | 'time' } +export type DatasetMetadataResponse = { + id: string + name: string + type: string +} + export type MetadataUpdatePayload = { name: string } @@ -409,6 +424,11 @@ export type DatasetResponse = { permission?: string | null } +export type DatasetMetadataBuiltInFieldResponse = { + name: string + type: string +} + export type DocumentMetadataOperation = { document_id: string metadata_list: Array @@ -431,6 +451,13 @@ export type HitTestingRecord = { tsne_position?: unknown } +export type DatasetMetadataListItemResponse = { + count?: number + id: string + name: string + type: string +} + export type DatasetContent = { content?: string content_type?: string @@ -977,9 +1004,7 @@ export type GetDatasetsMetadataBuiltInData = { } export type GetDatasetsMetadataBuiltInResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataBuiltInFieldsResponse } export type GetDatasetsMetadataBuiltInResponse @@ -1341,7 +1366,9 @@ export type PostDatasetsByDatasetIdDocumentsMetadataData = { } export type PostDatasetsByDatasetIdDocumentsMetadataResponses = { - 200: SimpleResultResponse + 204: { + [key: string]: never + } } export type PostDatasetsByDatasetIdDocumentsMetadataResponse @@ -2052,9 +2079,7 @@ export type GetDatasetsByDatasetIdMetadataData = { } export type GetDatasetsByDatasetIdMetadataResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataListResponse } export type GetDatasetsByDatasetIdMetadataResponse @@ -2070,9 +2095,7 @@ export type PostDatasetsByDatasetIdMetadataData = { } export type PostDatasetsByDatasetIdMetadataResponses = { - 200: { - [key: string]: unknown - } + 201: DatasetMetadataResponse } export type PostDatasetsByDatasetIdMetadataResponse @@ -2089,7 +2112,9 @@ export type PostDatasetsByDatasetIdMetadataBuiltInByActionData = { } export type PostDatasetsByDatasetIdMetadataBuiltInByActionResponses = { - 200: SimpleResultResponse + 204: { + [key: string]: never + } } export type PostDatasetsByDatasetIdMetadataBuiltInByActionResponse @@ -2125,9 +2150,7 @@ export type PatchDatasetsByDatasetIdMetadataByMetadataIdData = { } export type PatchDatasetsByDatasetIdMetadataByMetadataIdResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataResponse } export type PatchDatasetsByDatasetIdMetadataByMetadataIdResponse diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index 18d585a247e..984608ffcc5 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -204,6 +204,15 @@ export const zMetadataArgs = z.object({ type: z.enum(['number', 'string', 'time']), }) +/** + * DatasetMetadataResponse + */ +export const zDatasetMetadataResponse = z.object({ + id: z.string(), + name: z.string(), + type: z.string(), +}) + /** * MetadataUpdatePayload */ @@ -319,6 +328,21 @@ export const zDatasetResponse = z.object({ permission: z.string().nullish(), }) +/** + * DatasetMetadataBuiltInFieldResponse + */ +export const zDatasetMetadataBuiltInFieldResponse = z.object({ + name: z.string(), + type: z.string(), +}) + +/** + * DatasetMetadataBuiltInFieldsResponse + */ +export const zDatasetMetadataBuiltInFieldsResponse = z.object({ + fields: z.array(zDatasetMetadataBuiltInFieldResponse), +}) + /** * DocumentMetadataResponse */ @@ -368,6 +392,24 @@ export const zDatasetAndDocumentResponse = z.object({ documents: z.array(zDocumentResponse), }) +/** + * DatasetMetadataListItemResponse + */ +export const zDatasetMetadataListItemResponse = z.object({ + count: z.int().optional().default(0), + id: z.string(), + name: z.string(), + type: z.string(), +}) + +/** + * DatasetMetadataListResponse + */ +export const zDatasetMetadataListResponse = z.object({ + built_in_field_enabled: z.boolean(), + doc_metadata: z.array(zDatasetMetadataListItemResponse), +}) + export const zAppDetailKernel = z.object({ description: z.string().optional(), icon: z.string().optional(), @@ -966,9 +1008,9 @@ export const zPostDatasetsInitBody = zKnowledgeConfig export const zPostDatasetsInitResponse = zDatasetAndDocumentResponse /** - * Success + * Built-in fields retrieved successfully */ -export const zGetDatasetsMetadataBuiltInResponse = z.record(z.string(), z.unknown()) +export const zGetDatasetsMetadataBuiltInResponse = zDatasetMetadataBuiltInFieldsResponse /** * Success @@ -1149,9 +1191,9 @@ export const zPostDatasetsByDatasetIdDocumentsMetadataPath = z.object({ }) /** - * Success + * Documents metadata updated successfully */ -export const zPostDatasetsByDatasetIdDocumentsMetadataResponse = zSimpleResultResponse +export const zPostDatasetsByDatasetIdDocumentsMetadataResponse = z.record(z.string(), z.never()) export const zPatchDatasetsByDatasetIdDocumentsStatusByActionBatchPath = z.object({ action: z.string(), @@ -1566,9 +1608,9 @@ export const zGetDatasetsByDatasetIdMetadataPath = z.object({ }) /** - * Success + * Metadata retrieved successfully */ -export const zGetDatasetsByDatasetIdMetadataResponse = z.record(z.string(), z.unknown()) +export const zGetDatasetsByDatasetIdMetadataResponse = zDatasetMetadataListResponse export const zPostDatasetsByDatasetIdMetadataBody = zMetadataArgs @@ -1577,9 +1619,9 @@ export const zPostDatasetsByDatasetIdMetadataPath = z.object({ }) /** - * Success + * Metadata created successfully */ -export const zPostDatasetsByDatasetIdMetadataResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdMetadataResponse = zDatasetMetadataResponse export const zPostDatasetsByDatasetIdMetadataBuiltInByActionPath = z.object({ action: z.string(), @@ -1587,9 +1629,12 @@ export const zPostDatasetsByDatasetIdMetadataBuiltInByActionPath = z.object({ }) /** - * Success + * Action completed successfully */ -export const zPostDatasetsByDatasetIdMetadataBuiltInByActionResponse = zSimpleResultResponse +export const zPostDatasetsByDatasetIdMetadataBuiltInByActionResponse = z.record( + z.string(), + z.never(), +) export const zDeleteDatasetsByDatasetIdMetadataByMetadataIdPath = z.object({ dataset_id: z.string(), @@ -1612,12 +1657,9 @@ export const zPatchDatasetsByDatasetIdMetadataByMetadataIdPath = z.object({ }) /** - * Success + * Metadata updated successfully */ -export const zPatchDatasetsByDatasetIdMetadataByMetadataIdResponse = z.record( - z.string(), - z.unknown(), -) +export const zPatchDatasetsByDatasetIdMetadataByMetadataIdResponse = zDatasetMetadataResponse export const zGetDatasetsByDatasetIdNotionSyncPath = z.object({ dataset_id: z.string(), diff --git a/packages/contracts/generated/api/console/features/orpc.gen.ts b/packages/contracts/generated/api/console/features/orpc.gen.ts index e24ec3d9642..3463ccb015e 100644 --- a/packages/contracts/generated/api/console/features/orpc.gen.ts +++ b/packages/contracts/generated/api/console/features/orpc.gen.ts @@ -2,14 +2,35 @@ import { oc } from '@orpc/contract' -import { zGetFeaturesResponse } from './zod.gen' +import { zGetFeaturesResponse, zGetFeaturesVectorSpaceResponse } from './zod.gen' + +/** + * Get vector-space usage and limit for current tenant + * + * Get vector-space usage and limit for current tenant + */ +export const get = oc + .route({ + description: 'Get vector-space usage and limit for current tenant', + inputStructure: 'detailed', + method: 'GET', + operationId: 'getFeaturesVectorSpace', + path: '/features/vector-space', + summary: 'Get vector-space usage and limit for current tenant', + tags: ['console'], + }) + .output(zGetFeaturesVectorSpaceResponse) + +export const vectorSpace = { + get, +} /** * Get feature configuration for current tenant * * Get feature configuration for current tenant */ -export const get = oc +export const get2 = oc .route({ description: 'Get feature configuration for current tenant', inputStructure: 'detailed', @@ -22,7 +43,8 @@ export const get = oc .output(zGetFeaturesResponse) export const features = { - get, + get: get2, + vectorSpace, } export const contract = { diff --git a/packages/contracts/generated/api/console/features/types.gen.ts b/packages/contracts/generated/api/console/features/types.gen.ts index 411e062afb2..68b2dc0d9eb 100644 --- a/packages/contracts/generated/api/console/features/types.gen.ts +++ b/packages/contracts/generated/api/console/features/types.gen.ts @@ -75,3 +75,17 @@ export type GetFeaturesResponses = { } export type GetFeaturesResponse = GetFeaturesResponses[keyof GetFeaturesResponses] + +export type GetFeaturesVectorSpaceData = { + body?: never + path?: never + query?: never + url: '/features/vector-space' +} + +export type GetFeaturesVectorSpaceResponses = { + 200: LimitationModel +} + +export type GetFeaturesVectorSpaceResponse + = GetFeaturesVectorSpaceResponses[keyof GetFeaturesVectorSpaceResponses] diff --git a/packages/contracts/generated/api/console/features/zod.gen.ts b/packages/contracts/generated/api/console/features/zod.gen.ts index 9ace83a4335..0e26f296b60 100644 --- a/packages/contracts/generated/api/console/features/zod.gen.ts +++ b/packages/contracts/generated/api/console/features/zod.gen.ts @@ -93,3 +93,8 @@ export const zFeatureModel = z.object({ * Success */ export const zGetFeaturesResponse = zFeatureModel + +/** + * Success + */ +export const zGetFeaturesVectorSpaceResponse = zLimitationModel diff --git a/packages/contracts/generated/api/service/orpc.gen.ts b/packages/contracts/generated/api/service/orpc.gen.ts index 56471b73b9f..33d2c473614 100644 --- a/packages/contracts/generated/api/service/orpc.gen.ts +++ b/packages/contracts/generated/api/service/orpc.gen.ts @@ -1634,16 +1634,10 @@ export const byAction3 = { * Get all built-in metadata fields * * Get all built-in metadata fields - * - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated */ export const get15 = oc .route({ - deprecated: true, - description: - 'Get all built-in metadata fields\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + description: 'Get all built-in metadata fields', inputStructure: 'detailed', method: 'GET', operationId: 'getDatasetsByDatasetIdMetadataBuiltIn', @@ -1682,16 +1676,10 @@ export const delete7 = oc * Update metadata name * * Update metadata name - * - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated */ export const patch5 = oc .route({ - deprecated: true, - description: - 'Update metadata name\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + description: 'Update metadata name', inputStructure: 'detailed', method: 'PATCH', operationId: 'patchDatasetsByDatasetIdMetadataByMetadataId', @@ -1716,16 +1704,10 @@ export const byMetadataId = { * Get all metadata for a dataset * * Get all metadata for a dataset - * - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated */ export const get16 = oc .route({ - deprecated: true, - description: - 'Get all metadata for a dataset\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + description: 'Get all metadata for a dataset', inputStructure: 'detailed', method: 'GET', operationId: 'getDatasetsByDatasetIdMetadata', @@ -1740,16 +1722,10 @@ export const get16 = oc * Create metadata for a dataset * * Create metadata for a dataset - * - * Generated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate. - * - * @deprecated */ export const post28 = oc .route({ - deprecated: true, - description: - 'Create metadata for a dataset\n\nGenerated contract types may be inaccurate because backend OpenAPI annotations are incomplete. Do not migrate callers until the generated contract is accurate.', + description: 'Create metadata for a dataset', inputStructure: 'detailed', method: 'POST', operationId: 'postDatasetsByDatasetIdMetadata', diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index 9918edcb2a4..101be40c8cd 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -172,6 +172,37 @@ export type DatasetListQuery = { tag_ids?: Array } +export type DatasetMetadataActionResponse = { + result: string +} + +export type DatasetMetadataBuiltInFieldResponse = { + name: string + type: string +} + +export type DatasetMetadataBuiltInFieldsResponse = { + fields: Array +} + +export type DatasetMetadataListItemResponse = { + count?: number + id: string + name: string + type: string +} + +export type DatasetMetadataListResponse = { + built_in_field_enabled: boolean + doc_metadata: Array +} + +export type DatasetMetadataResponse = { + id: string + name: string + type: string +} + export type DatasetPermissionEnum = 'all_team_members' | 'only_me' | 'partial_members' export type DatasetUpdatePayload = { @@ -1666,7 +1697,7 @@ export type PostDatasetsByDatasetIdDocumentsMetadataError = PostDatasetsByDatasetIdDocumentsMetadataErrors[keyof PostDatasetsByDatasetIdDocumentsMetadataErrors] export type PostDatasetsByDatasetIdDocumentsMetadataResponses = { - 200: SimpleResultResponse + 200: DatasetMetadataActionResponse } export type PostDatasetsByDatasetIdDocumentsMetadataResponse @@ -2349,9 +2380,7 @@ export type GetDatasetsByDatasetIdMetadataError = GetDatasetsByDatasetIdMetadataErrors[keyof GetDatasetsByDatasetIdMetadataErrors] export type GetDatasetsByDatasetIdMetadataResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataListResponse } export type GetDatasetsByDatasetIdMetadataResponse @@ -2379,9 +2408,7 @@ export type PostDatasetsByDatasetIdMetadataError = PostDatasetsByDatasetIdMetadataErrors[keyof PostDatasetsByDatasetIdMetadataErrors] export type PostDatasetsByDatasetIdMetadataResponses = { - 201: { - [key: string]: unknown - } + 201: DatasetMetadataResponse } export type PostDatasetsByDatasetIdMetadataResponse @@ -2406,9 +2433,7 @@ export type GetDatasetsByDatasetIdMetadataBuiltInError = GetDatasetsByDatasetIdMetadataBuiltInErrors[keyof GetDatasetsByDatasetIdMetadataBuiltInErrors] export type GetDatasetsByDatasetIdMetadataBuiltInResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataBuiltInFieldsResponse } export type GetDatasetsByDatasetIdMetadataBuiltInResponse @@ -2437,7 +2462,7 @@ export type PostDatasetsByDatasetIdMetadataBuiltInByActionError = PostDatasetsByDatasetIdMetadataBuiltInByActionErrors[keyof PostDatasetsByDatasetIdMetadataBuiltInByActionErrors] export type PostDatasetsByDatasetIdMetadataBuiltInByActionResponses = { - 200: SimpleResultResponse + 200: DatasetMetadataActionResponse } export type PostDatasetsByDatasetIdMetadataBuiltInByActionResponse @@ -2497,9 +2522,7 @@ export type PatchDatasetsByDatasetIdMetadataByMetadataIdError = PatchDatasetsByDatasetIdMetadataByMetadataIdErrors[keyof PatchDatasetsByDatasetIdMetadataByMetadataIdErrors] export type PatchDatasetsByDatasetIdMetadataByMetadataIdResponses = { - 200: { - [key: string]: unknown - } + 200: DatasetMetadataResponse } export type PatchDatasetsByDatasetIdMetadataByMetadataIdResponse diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index f99d3f08a1d..3bdaea39767 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -209,6 +209,55 @@ export const zDatasetListQuery = z.object({ tag_ids: z.array(z.string()).optional(), }) +/** + * DatasetMetadataActionResponse + */ +export const zDatasetMetadataActionResponse = z.object({ + result: z.string(), +}) + +/** + * DatasetMetadataBuiltInFieldResponse + */ +export const zDatasetMetadataBuiltInFieldResponse = z.object({ + name: z.string(), + type: z.string(), +}) + +/** + * DatasetMetadataBuiltInFieldsResponse + */ +export const zDatasetMetadataBuiltInFieldsResponse = z.object({ + fields: z.array(zDatasetMetadataBuiltInFieldResponse), +}) + +/** + * DatasetMetadataListItemResponse + */ +export const zDatasetMetadataListItemResponse = z.object({ + count: z.int().optional().default(0), + id: z.string(), + name: z.string(), + type: z.string(), +}) + +/** + * DatasetMetadataListResponse + */ +export const zDatasetMetadataListResponse = z.object({ + built_in_field_enabled: z.boolean(), + doc_metadata: z.array(zDatasetMetadataListItemResponse), +}) + +/** + * DatasetMetadataResponse + */ +export const zDatasetMetadataResponse = z.object({ + id: z.string(), + name: z.string(), + type: z.string(), +}) + /** * DatasetPermissionEnum */ @@ -1160,7 +1209,7 @@ export const zPostDatasetsByDatasetIdDocumentsMetadataPath = z.object({ /** * Documents metadata updated successfully */ -export const zPostDatasetsByDatasetIdDocumentsMetadataResponse = zSimpleResultResponse +export const zPostDatasetsByDatasetIdDocumentsMetadataResponse = zDatasetMetadataActionResponse export const zPatchDatasetsByDatasetIdDocumentsStatusByActionPath = z.object({ action: z.string(), @@ -1453,7 +1502,7 @@ export const zGetDatasetsByDatasetIdMetadataPath = z.object({ /** * Metadata retrieved successfully */ -export const zGetDatasetsByDatasetIdMetadataResponse = z.record(z.string(), z.unknown()) +export const zGetDatasetsByDatasetIdMetadataResponse = zDatasetMetadataListResponse export const zPostDatasetsByDatasetIdMetadataBody = zMetadataArgs @@ -1464,7 +1513,7 @@ export const zPostDatasetsByDatasetIdMetadataPath = z.object({ /** * Metadata created successfully */ -export const zPostDatasetsByDatasetIdMetadataResponse = z.record(z.string(), z.unknown()) +export const zPostDatasetsByDatasetIdMetadataResponse = zDatasetMetadataResponse export const zGetDatasetsByDatasetIdMetadataBuiltInPath = z.object({ dataset_id: z.string(), @@ -1473,7 +1522,7 @@ export const zGetDatasetsByDatasetIdMetadataBuiltInPath = z.object({ /** * Built-in fields retrieved successfully */ -export const zGetDatasetsByDatasetIdMetadataBuiltInResponse = z.record(z.string(), z.unknown()) +export const zGetDatasetsByDatasetIdMetadataBuiltInResponse = zDatasetMetadataBuiltInFieldsResponse export const zPostDatasetsByDatasetIdMetadataBuiltInByActionPath = z.object({ action: z.string(), @@ -1483,7 +1532,8 @@ export const zPostDatasetsByDatasetIdMetadataBuiltInByActionPath = z.object({ /** * Action completed successfully */ -export const zPostDatasetsByDatasetIdMetadataBuiltInByActionResponse = zSimpleResultResponse +export const zPostDatasetsByDatasetIdMetadataBuiltInByActionResponse + = zDatasetMetadataActionResponse export const zDeleteDatasetsByDatasetIdMetadataByMetadataIdPath = z.object({ dataset_id: z.string(), @@ -1508,10 +1558,7 @@ export const zPatchDatasetsByDatasetIdMetadataByMetadataIdPath = z.object({ /** * Metadata updated successfully */ -export const zPatchDatasetsByDatasetIdMetadataByMetadataIdResponse = z.record( - z.string(), - z.unknown(), -) +export const zPatchDatasetsByDatasetIdMetadataByMetadataIdResponse = zDatasetMetadataResponse export const zGetDatasetsByDatasetIdPipelineDatasourcePluginsPath = z.object({ dataset_id: z.string(), diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index 396c42e187c..325454d4660 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -58,13 +58,15 @@ Utilities: ## Form contract -Dify UI's form primitives are a Base UI composition layer for native form semantics, field accessibility, and design-system styling. They are intentionally not a form state-management framework. See the upstream [Base UI Form], [Base UI Field], and [Base UI Fieldset] docs for the underlying component contracts. +Dify UI's form primitives are a Base UI composition layer for native form semantics, field accessibility, and design-system styling. They are intentionally not a form state-management framework. See the upstream [Base UI forms handbook], [Base UI Form], [Base UI Field], and [Base UI Fieldset] docs for the underlying component contracts. Use `Form` for the submit boundary. It renders a native `
`, preserves Enter-to-submit and submit-button behavior, and adds Base UI's `onFormSubmit`, `errors`, `actionsRef`, and `validationMode` APIs for structured values and consolidated field validation. Prefer it over a bare `` when the form is composed with Dify UI fields. -Use `FieldRoot` for each named field. A field must have a stable `name`, a visible `FieldLabel`, and either a `FieldControl` or another control that participates in the same Base UI field context. `FieldLabel`, `FieldDescription`, and `FieldError` provide the label and message relationships that screen readers need, while the Dify wrapper adds the default Form Input Set styling from the design system. +Use `FieldRoot` for each standalone named field. A field must have a stable `name`, a label relationship, and either a `FieldControl` or another control that participates in the same Base UI field context. Prefer a visible label for normal form rows; when the surrounding UI already supplies the visible text, use the matching label primitive visually hidden or put `aria-label` on the actual interactive control. `FieldDescription` and `FieldError` provide the message relationships that screen readers need, while the Dify wrapper adds the default Form Input Set styling from the design system. -Use `FieldsetRoot` and `FieldsetLegend` when one field is represented by a group of related controls, such as checkbox groups, radio groups, or multi-thumb sliders. Compose group controls with the Base UI pattern: +Choose the label primitive by the control semantics. Text-like inputs, input-based `Combobox` / `Autocomplete`, single `Checkbox` / `Radio`, `Switch`, and `NumberField` use `FieldLabel`. Trigger-based `Select` fields use `SelectLabel`; `Slider` fields use `SliderLabel`, with per-thumb `aria-label` only when the thumbs need distinct names. `SelectGroupLabel` and `AutocompleteGroupLabel` only label grouped options inside their popup content; they are not field labels. + +Use `FieldsetRoot` and `FieldsetLegend` when one field is represented by a group of related controls, such as checkbox groups, radio groups, multi-thumb sliders, or a section that combines several inputs. For checkbox and radio groups, wrap each option with `FieldItem` and give each option its own label: ```tsx @@ -82,9 +84,9 @@ Use `FieldsetRoot` and `FieldsetLegend` when one field is represented by a group `FieldsetRoot` provides the group semantics and legend relationship. It does not own the interactive state of the grouped control. Pass `disabled`, `value`, `defaultValue`, and change handlers to the actual group primitive (`CheckboxGroup`, radio group, slider root, etc.) instead of relying on the fieldset wrapper to manage them. -For complex business forms, keep state ownership outside these primitives. TanStack Form, zod, server validation, dialog reset behavior, and schema-driven rendering belong to the feature layer in `web/`; they should pass `name`, `invalid`, `dirty`, `touched`, `value`, `onValueChange`, and errors into these primitives rather than replacing the field semantics. +For complex business forms, keep state ownership outside these primitives. TanStack Form, zod, server validation, dialog reset behavior, and schema-driven rendering belong to the feature layer in `web/`; they should pass `name`, `invalid`, `dirty`, `touched`, `value`, `onValueChange`, and errors into these primitives rather than replacing the field semantics. In this repo, `web/app/components/base/form` is the TanStack/schema runtime adapter; `packages/dify-ui` remains the primitive layer. -Migration rule for `web/`: if a UI has a save/submit action, do not leave it as unrelated `Input` and `Button` pieces. Give it a real submit boundary with `Form` or a native ``, attach visible field names through `FieldLabel`, expose helper/error text through `FieldDescription` / `FieldError`, and keep non-submit buttons as `type="button"`. +Migration rule for `web/`: if a UI has a save/submit action, do not leave it as unrelated `Input` and `Button` pieces. Give it a real submit boundary with `Form` or a native ``, attach visible field names through the appropriate label primitive (`FieldLabel`, `SelectLabel`, `SliderLabel`, or `FieldsetLegend`), expose helper/error text through `FieldDescription` / `FieldError`, and keep non-submit buttons as `type="button"`. ## Tailwind CSS v4 integration @@ -180,5 +182,6 @@ See `[AGENTS.md](./AGENTS.md)` for: [Base UI Form]: https://base-ui.com/react/components/form [Base UI Portal]: https://base-ui.com/react/overview/quick-start#portals [Base UI docs index]: https://base-ui.com/llms.txt +[Base UI forms handbook]: https://base-ui.com/react/handbook/forms [Base UI]: https://base-ui.com/react [Overlay & portal contract]: #overlay--portal-contract diff --git a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx index 5c56dc4c07f..72ed0420333 100644 --- a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx @@ -6,12 +6,12 @@ import { AutocompleteContent, AutocompleteEmpty, AutocompleteGroup, + AutocompleteGroupLabel, AutocompleteInput, AutocompleteInputGroup, AutocompleteItem, AutocompleteItemIndicator, AutocompleteItemText, - AutocompleteLabel, AutocompleteList, AutocompleteSeparator, AutocompleteStatus, @@ -230,7 +230,7 @@ describe('Autocomplete wrappers', () => { - Resources + Resources Workflow diff --git a/packages/dify-ui/src/autocomplete/index.stories.tsx b/packages/dify-ui/src/autocomplete/index.stories.tsx index 71c7c6607dc..79f8983cd21 100644 --- a/packages/dify-ui/src/autocomplete/index.stories.tsx +++ b/packages/dify-ui/src/autocomplete/index.stories.tsx @@ -10,11 +10,11 @@ import { AutocompleteContent, AutocompleteEmpty, AutocompleteGroup, + AutocompleteGroupLabel, AutocompleteInput, AutocompleteInputGroup, AutocompleteItem, AutocompleteItemText, - AutocompleteLabel, AutocompleteList, AutocompleteSeparator, AutocompleteStatus, @@ -232,7 +232,7 @@ const GroupedSuggestionList = () => { {groups.map((group, groupIndex) => ( {groupIndex > 0 && } - {group.label} + {group.label} {(item: Suggestion) => ( @@ -252,7 +252,7 @@ const CommandPaletteList = () => { {groups.map((group, groupIndex) => ( {groupIndex > 0 && } - {group.label} + {group.label} {(item: Suggestion) => ( @@ -475,7 +475,7 @@ const FuzzyMatchingDemo = () => { } const meta = { - title: 'Base/UI/Autocomplete', + title: 'Base/Form/Autocomplete', component: Autocomplete, parameters: { layout: 'centered', @@ -599,7 +599,7 @@ export const LimitResults: Story = { export const CommandPalette: Story = { render: () => ( -
+
) diff --git a/packages/dify-ui/src/context-menu/index.tsx b/packages/dify-ui/src/context-menu/index.tsx index e33a94f03f2..422dff6b622 100644 --- a/packages/dify-ui/src/context-menu/index.tsx +++ b/packages/dify-ui/src/context-menu/index.tsx @@ -6,7 +6,6 @@ import type { Placement } from '../placement' import { ContextMenu as BaseContextMenu } from '@base-ui/react/context-menu' import { cn } from '../cn' import { - overlayBackdropClassName, overlayDestructiveClassName, overlayIndicatorClassName, overlayLabelClassName, @@ -24,6 +23,8 @@ export const ContextMenuTrigger = BaseContextMenu.Trigger export const ContextMenuSub = BaseContextMenu.SubmenuRoot export const ContextMenuGroup = BaseContextMenu.Group export const ContextMenuRadioGroup = BaseContextMenu.RadioGroup +export type ContextMenuActions = BaseContextMenu.Root.Actions +// Intentionally no public Backdrop export; Base UI handles context-menu modal dismissal internally. type ContextMenuContentProps = { children: ReactNode @@ -50,7 +51,6 @@ type ContextMenuPopupRenderProps = Required - {withBackdrop && ( - - )} @@ -56,7 +56,7 @@ export function FieldLabel({ }: FieldLabelProps) { return ( ) diff --git a/packages/dify-ui/src/text-control-variants.ts b/packages/dify-ui/src/form-control-shared.ts similarity index 92% rename from packages/dify-ui/src/text-control-variants.ts rename to packages/dify-ui/src/form-control-shared.ts index 2943c00cd79..d8454fce52b 100644 --- a/packages/dify-ui/src/text-control-variants.ts +++ b/packages/dify-ui/src/form-control-shared.ts @@ -1,5 +1,7 @@ import { cva } from 'class-variance-authority' +export const formLabelClassName = 'w-fit py-1 text-text-secondary system-sm-medium data-disabled:cursor-not-allowed' + export const textControlVariants = cva( [ 'w-full appearance-none border border-transparent bg-components-input-bg-normal text-components-input-text-filled caret-primary-600 outline-hidden transition-[background-color,border-color,box-shadow]', diff --git a/packages/dify-ui/src/input/index.tsx b/packages/dify-ui/src/input/index.tsx index cabac346c19..4f48f3cdae6 100644 --- a/packages/dify-ui/src/input/index.tsx +++ b/packages/dify-ui/src/input/index.tsx @@ -4,7 +4,7 @@ import type { Input as BaseInputNS } from '@base-ui/react/input' import type { VariantProps } from 'class-variance-authority' import { Input as BaseInput } from '@base-ui/react/input' import { cn } from '../cn' -import { textControlVariants } from '../text-control-variants' +import { textControlVariants } from '../form-control-shared' export type InputSize = NonNullable['size']> diff --git a/packages/dify-ui/src/radio-group/index.stories.tsx b/packages/dify-ui/src/radio-group/index.stories.tsx index c2c24518064..d28d9b06b00 100644 --- a/packages/dify-ui/src/radio-group/index.stories.tsx +++ b/packages/dify-ui/src/radio-group/index.stories.tsx @@ -17,7 +17,7 @@ const meta = { layout: 'centered', docs: { description: { - component: 'RadioGroup primitive built on Base UI. For normal form rows, compose FieldRoot, FieldsetRoot, FieldLabel, RadioGroup, and Radio. For option cards, make the card itself a RadioRoot with variant="unstyled" and render RadioControl inside it.', + component: 'RadioGroup primitive built on Base UI. For normal form rows, compose FieldRoot, FieldsetRoot, FieldLabel, RadioGroup, and Radio. For option cards, wrap each option in FieldItem and make the card itself a RadioRoot with variant="unstyled".', }, }, }, @@ -130,26 +130,27 @@ function OptionCardsDemo() { description: 'Write a prompt for this app and keep full control.', }, ].map(option => ( - } - className="w-full rounded-xl border border-components-option-card-option-border bg-components-option-card-option-bg p-4 text-left transition-colors hover:bg-state-base-hover data-checked:border-components-option-card-option-selected-border data-checked:bg-components-option-card-option-selected-bg" - > -
-
-
- {option.title} -
-
- {option.description} + + } + className="w-full rounded-xl border border-components-option-card-option-border bg-components-option-card-option-bg p-4 text-left transition-colors hover:bg-state-base-hover data-checked:border-components-option-card-option-selected-border data-checked:bg-components-option-card-option-selected-bg" + > +
+
+
+ {option.title} +
+
+ {option.description} +
+
-
- + + ))} @@ -161,7 +162,7 @@ export const OptionCards: Story = { parameters: { docs: { description: { - story: 'Use RadioRoot with variant="unstyled" when the entire option card is the radio. RadioControl renders the visual dot inside the card.', + story: 'Wrap each option card in FieldItem, then use RadioRoot with variant="unstyled" when the entire card is the radio. RadioControl renders the visual dot inside the card.', }, }, }, diff --git a/packages/dify-ui/src/select/__tests__/index.spec.tsx b/packages/dify-ui/src/select/__tests__/index.spec.tsx index 2fd4e23bdb2..ccdb13c61d5 100644 --- a/packages/dify-ui/src/select/__tests__/index.spec.tsx +++ b/packages/dify-ui/src/select/__tests__/index.spec.tsx @@ -1,5 +1,16 @@ import { render } from 'vitest-browser-react' -import { Select, SelectContent, SelectItem, SelectItemIndicator, SelectItemText, SelectTrigger, SelectValue } from '../index' +import { + Select, + SelectContent, + SelectGroup, + SelectGroupLabel, + SelectItem, + SelectItemIndicator, + SelectItemText, + SelectLabel, + SelectTrigger, + SelectValue, +} from '../index' const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement const renderWithSafeViewport = (ui: import('react').ReactNode) => render( @@ -84,6 +95,26 @@ describe('Select wrappers', () => { }) describe('SelectTrigger', () => { + it('should use SelectLabel as the trigger accessible name', async () => { + const screen = await renderWithSafeViewport( + , + ) + + await expect.element(screen.getByRole('combobox', { name: 'City' })).toBeInTheDocument() + await expect.element(screen.getByText('City')).toHaveClass('py-1', 'system-sm-medium', 'text-text-secondary') + }) + it('should forward native trigger props when trigger props are provided', async () => { const screen = await renderOpenSelect({ triggerProps: { @@ -179,6 +210,28 @@ describe('Select wrappers', () => { }) describe('SelectContent', () => { + it('should render SelectGroupLabel for grouped options without naming the trigger', async () => { + const screen = await renderWithSafeViewport( + , + ) + + await expect.element(screen.getByRole('combobox', { name: 'city select' })).toBeInTheDocument() + await expect.element(screen.getByText('Popular cities')).toHaveClass('custom-label') + }) + it('should use positioning attributes when placement is not provided', async () => { const screen = await renderOpenSelect() diff --git a/packages/dify-ui/src/select/index.stories.tsx b/packages/dify-ui/src/select/index.stories.tsx index 697266dcec2..40461be1999 100644 --- a/packages/dify-ui/src/select/index.stories.tsx +++ b/packages/dify-ui/src/select/index.stories.tsx @@ -4,6 +4,7 @@ import { Select, SelectContent, SelectGroup, + SelectGroupLabel, SelectItem, SelectItemIndicator, SelectItemText, @@ -62,6 +63,29 @@ export const Default: Story = { ), } +export const WithVisibleLabel: Story = { + render: () => ( +
+ +
+ ), +} + export const WithPlaceholder: Story = { render: () => (
@@ -123,7 +147,7 @@ export const WithGroupsAndSeparator: Story = { - OpenAI + OpenAI GPT-5 @@ -135,7 +159,7 @@ export const WithGroupsAndSeparator: Story = { - Anthropic + Anthropic Claude Opus @@ -147,7 +171,7 @@ export const WithGroupsAndSeparator: Story = { - Google + Google Gemini 2.5 diff --git a/packages/dify-ui/src/select/index.tsx b/packages/dify-ui/src/select/index.tsx index 0e2c53c3dc8..3dd145be98b 100644 --- a/packages/dify-ui/src/select/index.tsx +++ b/packages/dify-ui/src/select/index.tsx @@ -6,8 +6,10 @@ import type { Placement } from '../placement' import { Select as BaseSelect } from '@base-ui/react/select' import { cva } from 'class-variance-authority' import { cn } from '../cn' +import { formLabelClassName } from '../form-control-shared' import { overlayLabelClassName, + overlayPopupAnimationClassName, overlaySeparatorClassName, } from '../overlay-shared' import { parsePlacement } from '../placement' @@ -70,6 +72,18 @@ export function SelectTrigger({ export function SelectLabel({ className, ...props +}: BaseSelect.Label.Props) { + return ( + + ) +} + +export function SelectGroupLabel({ + className, + ...props }: BaseSelect.GroupLabel.Props) { return ( element as HTMLElement @@ -77,4 +85,21 @@ describe('Slider', () => { expect(screen.container.querySelector('script')).not.toBeInTheDocument() }) + + it('should expose SliderLabel for composed slider fields', async () => { + const screen = await render( + + Temperature + + + + + + + , + ) + + await expect.element(screen.getByRole('slider', { name: 'Temperature' })).toHaveAttribute('aria-valuenow', '50') + await expect.element(screen.getByText('Temperature')).toHaveClass('py-1', 'system-sm-medium', 'text-text-secondary') + }) }) diff --git a/packages/dify-ui/src/slider/index.stories.tsx b/packages/dify-ui/src/slider/index.stories.tsx index 844a9844064..11b22f0de34 100644 --- a/packages/dify-ui/src/slider/index.stories.tsx +++ b/packages/dify-ui/src/slider/index.stories.tsx @@ -1,7 +1,15 @@ import type { Meta, StoryObj } from '@storybook/react-vite' import type * as React from 'react' import { useState } from 'react' -import { Slider } from '.' +import { + Slider, + SliderControl, + SliderIndicator, + SliderLabel, + SliderRoot, + SliderThumb, + SliderTrack, +} from '.' const meta = { title: 'Base/Form/Slider', @@ -90,3 +98,17 @@ export const Disabled: Story = { disabled: true, }, } + +export const ComposedWithLabel: Story = { + render: () => ( + + Temperature + + + + + + + + ), +} diff --git a/packages/dify-ui/src/slider/index.tsx b/packages/dify-ui/src/slider/index.tsx index eafcecf7518..23719e5c0d3 100644 --- a/packages/dify-ui/src/slider/index.tsx +++ b/packages/dify-ui/src/slider/index.tsx @@ -2,9 +2,22 @@ import { Slider as BaseSlider } from '@base-ui/react/slider' import { cn } from '../cn' +import { formLabelClassName } from '../form-control-shared' export const SliderRoot = BaseSlider.Root +export function SliderLabel({ + className, + ...props +}: BaseSlider.Label.Props) { + return ( + + ) +} + type SliderRootProps = BaseSlider.Root.Props const sliderControlClassName = cn( diff --git a/web/__mocks__/base-ui-select.tsx b/web/__mocks__/base-ui-select.tsx index 76551644192..a695bebe14c 100644 --- a/web/__mocks__/base-ui-select.tsx +++ b/web/__mocks__/base-ui-select.tsx @@ -61,5 +61,6 @@ export const SelectItem = ({ export const SelectItemText = ({ children }: { children?: ReactNode }) => <>{children} export const SelectItemIndicator = ({ children }: { children?: ReactNode }) => <>{children} export const SelectGroup = ({ children }: { children?: ReactNode }) => <>{children} -export const SelectLabel = ({ children }: { children?: ReactNode }) => <>{children} +export const SelectLabel = () => null +export const SelectGroupLabel = ({ children }: { children?: ReactNode }) => <>{children} export const SelectSeparator = (props: React.HTMLAttributes) =>
diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 3113e367519..f1d96ad2c54 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -53,6 +53,9 @@ vi.mock('@/service/use-billing', () => ({ refetch: mockRefetch, }), useBindPartnerStackInfo: () => ({ mutateAsync: vi.fn() }), + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), })) vi.mock('@/service/use-education', () => ({ diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 707f1d690a3..58b531661e6 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -60,6 +60,9 @@ vi.mock('@/service/use-billing', () => ({ isFetching: false, refetch: vi.fn(), }), + useCurrentPlanVectorSpace: () => ({ + data: undefined, + }), })) // ─── Navigation mocks ─────────────────────────────────────────────────────── diff --git a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx index 3c0ae1befc8..51f4a16a255 100644 --- a/web/app/components/app/configuration/config/agent/agent-setting/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-setting/index.tsx @@ -2,6 +2,7 @@ import type { FC } from 'react' import type { AgentConfig } from '@/models/debug' import { Button } from '@langgenius/dify-ui/button' +import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset' import { Slider } from '@langgenius/dify-ui/slider' import { RiCloseLine } from '@remixicon/react' import { useClickAway } from 'ahooks' @@ -34,6 +35,7 @@ const AgentSetting: FC = ({ const [tempPayload, setTempPayload] = useState(payload) const ref = useRef(null) const [mounted, setMounted] = useState(false) + const maximumIterationsLabel = t('agent.setting.maximumIterations.name', { ns: 'appDebug' }) useClickAway(() => { if (mounted) @@ -96,10 +98,11 @@ const AgentSetting: FC = ({ icon={ } - name={t('agent.setting.maximumIterations.name', { ns: 'appDebug' })} + name={maximumIterationsLabel} description={t('agent.setting.maximumIterations.description', { ns: 'appDebug' })} > -
+ + {maximumIterationsLabel} = ({ max_iteration: value, }) }} - aria-label={t('agent.setting.maximumIterations.name', { ns: 'appDebug' })} + aria-label={maximumIterationsLabel} /> = ({ }) }} /> -
+ {!isFunctionCall && ( diff --git a/web/app/components/base/features/new-feature-panel/follow-up-setting-modal.tsx b/web/app/components/base/features/new-feature-panel/follow-up-setting-modal.tsx index c99f20f842d..e549c6ce475 100644 --- a/web/app/components/base/features/new-feature-panel/follow-up-setting-modal.tsx +++ b/web/app/components/base/features/new-feature-panel/follow-up-setting-modal.tsx @@ -8,7 +8,7 @@ import type { import { Button } from '@langgenius/dify-ui/button' import { cn } from '@langgenius/dify-ui/cn' import { Dialog, DialogCloseButton, DialogContent, DialogTitle } from '@langgenius/dify-ui/dialog' -import { FieldRoot } from '@langgenius/dify-ui/field' +import { FieldItem, FieldRoot } from '@langgenius/dify-ui/field' import { FieldsetLegend, FieldsetRoot } from '@langgenius/dify-ui/fieldset' import { RadioControl, RadioRoot } from '@langgenius/dify-ui/radio' import { RadioGroup } from '@langgenius/dify-ui/radio-group' @@ -161,70 +161,74 @@ const FollowUpSettingModal = ({ {t('feature.suggestedQuestionsAfterAnswer.modal.promptLabel', { ns: 'appDebug' })} - } - className={cn( - 'w-full rounded-xl border p-4 text-left transition-colors', - promptMode === PROMPT_MODE.default - ? 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg' - : 'border-components-option-card-option-border bg-components-option-card-option-bg hover:bg-state-base-hover', - )} - > -
-
-
- {t('feature.suggestedQuestionsAfterAnswer.modal.defaultPromptOption', { ns: 'appDebug' })} -
-
- {t('feature.suggestedQuestionsAfterAnswer.modal.defaultPromptOptionDescription', { ns: 'appDebug' })} + + } + className={cn( + 'w-full rounded-xl border p-4 text-left transition-colors', + promptMode === PROMPT_MODE.default + ? 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg' + : 'border-components-option-card-option-border bg-components-option-card-option-bg hover:bg-state-base-hover', + )} + > +
+
+
+ {t('feature.suggestedQuestionsAfterAnswer.modal.defaultPromptOption', { ns: 'appDebug' })} +
+
+ {t('feature.suggestedQuestionsAfterAnswer.modal.defaultPromptOptionDescription', { ns: 'appDebug' })} +
+
-
- {promptMode === PROMPT_MODE.default && ( -
-
- {DEFAULT_FOLLOW_UP_PROMPT} + {promptMode === PROMPT_MODE.default && ( +
+
+ {DEFAULT_FOLLOW_UP_PROMPT} +
+ )} + + + + } + className={cn( + 'w-full rounded-xl border p-4 text-left transition-colors', + promptMode === PROMPT_MODE.custom + ? 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg' + : 'border-components-option-card-option-border bg-components-option-card-option-bg hover:bg-state-base-hover', + )} + > +
+
+
+ {t('feature.suggestedQuestionsAfterAnswer.modal.customPromptOption', { ns: 'appDebug' })} +
+
+ {t('feature.suggestedQuestionsAfterAnswer.modal.customPromptOptionDescription', { ns: 'appDebug' })} +
+
+
- )} -
- } - className={cn( - 'w-full rounded-xl border p-4 text-left transition-colors', - promptMode === PROMPT_MODE.custom - ? 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg' - : 'border-components-option-card-option-border bg-components-option-card-option-bg hover:bg-state-base-hover', - )} - > -
-
-
- {t('feature.suggestedQuestionsAfterAnswer.modal.customPromptOption', { ns: 'appDebug' })} -
-
- {t('feature.suggestedQuestionsAfterAnswer.modal.customPromptOptionDescription', { ns: 'appDebug' })} -
-
-
- {promptMode === PROMPT_MODE.custom && ( -